Merge pull request #4294 from dhiltgen/harden_subprocess_reaping

Harden subprocess reaping
Daniel Hiltgen 2024-05-09 14:02:16 -07:00 committed by GitHub
commit d0425f26cf
2 changed files with 65 additions and 59 deletions

File 1 of 2: the bundled llama.cpp HTTP server (server.cpp)

@@ -2727,7 +2727,7 @@ static json format_detokenized_response(std::string content)
 static void log_server_request(const httplib::Request &req, const httplib::Response &res)
 {
     // skip GH copilot requests when using default port
-    if (req.path == "/v1/health" || req.path == "/v1/completions")
+    if (req.path == "/health" || req.path == "/v1/health" || req.path == "/v1/completions")
     {
         return;
     }
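
Note: the Go supervisor (second file, below) now polls the bare /health path from the moment the socket is bound, so that path joins the log-skip list to keep poll traffic out of the request log. The same idea in Go, as a minimal logging-middleware sketch (illustrative names, not code from this PR):

    package main

    import (
        "log/slog"
        "net/http"
    )

    // logRequests logs every request except health polls, which would
    // otherwise flood the log while a client waits for startup.
    func logRequests(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            if r.URL.Path != "/health" && r.URL.Path != "/v1/health" {
                slog.Info("request", "method", r.Method, "path", r.URL.Path)
            }
            next.ServeHTTP(w, r)
        })
    }

    func main() {
        mux := http.NewServeMux()
        mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
            w.Write([]byte(`{"status":"ok"}`))
        })
        http.ListenAndServe(":8080", logRequests(mux))
    }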
@@ -3054,6 +3054,26 @@ int main(int argc, char **argv) {
         log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
     }
 
+    if (sparams.n_threads_http < 1) {
+        // +2 threads for monitoring endpoints
+        sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
+    }
+    log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
+    svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
+
+    LOG_INFO("HTTP server listening", log_data);
+    // run the HTTP server in a thread - see comment below
+    std::thread t([&]()
+            {
+                if (!svr.listen_after_bind())
+                {
+                    state.store(SERVER_STATE_ERROR);
+                    return 1;
+                }
+
+                return 0;
+            });
+
     // load the model
     if (!llama.load_model(params))
     {
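
Note: this hunk is a relocation. The block that sizes the HTTP thread pool and starts the listener thread used to run after the model was loaded (see the matching deletion in the next hunk); moving it ahead of llama.load_model means /health answers while the weights are still loading instead of the connection being refused. A minimal Go sketch of the same bind-before-initialize pattern (all names illustrative):

    package main

    import (
        "net/http"
        "sync/atomic"
        "time"
    )

    func main() {
        var ready atomic.Bool

        mux := http.NewServeMux()
        // Health answers as soon as the socket is bound, even though the
        // expensive initialization below has not finished yet.
        mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
            if ready.Load() {
                w.Write([]byte(`{"status":"ok"}`))
            } else {
                w.Write([]byte(`{"status":"loading model"}`))
            }
        })

        // Bind and serve on a background goroutine first...
        go http.ListenAndServe(":8080", mux)

        // ...then do the slow work (stand-in for loading model weights).
        time.Sleep(5 * time.Second)
        ready.Store(true)

        select {} // keep serving
    }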
@@ -3258,26 +3278,6 @@ int main(int argc, char **argv) {
         }*/
     //);
 
-    if (sparams.n_threads_http < 1) {
-        // +2 threads for monitoring endpoints
-        sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
-    }
-    log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
-    svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
-
-    LOG_INFO("HTTP server listening", log_data);
-    // run the HTTP server in a thread - see comment below
-    std::thread t([&]()
-            {
-                if (!svr.listen_after_bind())
-                {
-                    state.store(SERVER_STATE_ERROR);
-                    return 1;
-                }
-
-                return 0;
-            });
-
     llama.queue_tasks.on_new_task(std::bind(
         &llama_server_context::process_single_task, &llama, std::placeholders::_1));
     llama.queue_tasks.on_finish_multitask(std::bind(

File 2 of 2: the Go runner supervisor (server.go)

@@ -53,6 +53,7 @@ type llmServer struct {
 	estimatedTotal uint64 // Total size of model
 	totalLayers    uint64
 	gpuCount       int
+	loadDuration   time.Duration // Record how long it took the model to load
 	sem *semaphore.Weighted
 }
@@ -291,6 +292,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			sem:         semaphore.NewWeighted(int64(numParallel)),
 			totalLayers: ggml.KV().BlockCount() + 1,
 			gpuCount:    gpuCount,
+			done:        make(chan error, 1),
 		}
 
 		s.cmd.Env = os.Environ()
@@ -339,6 +341,11 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			continue
 		}
 
+		// reap subprocess when it exits
+		go func() {
+			s.done <- s.cmd.Wait()
+		}()
+
 		return s, nil
 	}
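
Note: this is the heart of the hardening. exec.Cmd.Wait may be called only once per process, so a dedicated goroutine performs that single Wait (reaping the child and populating cmd.ProcessState) and deposits the result in s.done. The channel is created with a buffer of 1 (previous hunk) so the send completes even if nothing is receiving at that moment, and the goroutine can never leak. A self-contained sketch of the pattern (assumed names, not the PR's code):

    package main

    import (
        "fmt"
        "os/exec"
    )

    func main() {
        cmd := exec.Command("sleep", "1") // stand-in for the runner subprocess
        if err := cmd.Start(); err != nil {
            panic(err)
        }

        // Buffered so the reaper can deposit the exit status and terminate
        // even if no one is waiting on the channel yet.
        done := make(chan error, 1)

        // Exactly one Wait call per process: it reaps the child and fills
        // in cmd.ProcessState; a second Wait would return an error.
        go func() {
            done <- cmd.Wait()
        }()

        // Whichever code path cares about the exit receives it here; in the
        // PR that is WaitUntilRunning (via select) or Close.
        err := <-done
        fmt.Println("child exited:", err, "state:", cmd.ProcessState)
    }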
@@ -483,13 +490,11 @@ func (s *llmServer) Ping(ctx context.Context) error {
 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
+	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
-	ticker := time.NewTicker(50 * time.Millisecond)
-	defer ticker.Stop()
 
 	slog.Info("waiting for llama runner to start responding")
 	var lastStatus ServerStatus = -1
 
 	for {
 		select {
 		case <-ctx.Done():
@@ -501,7 +506,8 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 				msg = s.status.LastErrMsg
 			}
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
-		case <-ticker.C:
+		default:
+		}
 		if time.Now().After(expiresAt) {
 			// timeout
 			msg := ""
@@ -517,25 +523,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			}
 			return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 		}
-		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
+		c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
 		defer cancel()
-		status, err := s.getServerStatus(ctx)
-		if err != nil && lastStatus != status {
-			slog.Debug("server not yet available", "error", err)
-			lastStatus = status
-			continue
+		status, _ := s.getServerStatus(c)
+		if lastStatus != status && status != ServerStatusReady {
+			// Only log on status changes
+			slog.Info("waiting for server to become available", "status", status.ToString())
 		}
 		switch status {
-		case ServerStatusLoadingModel:
-			// TODO - this state never seems to happen with the current server.cpp code (bug?)
-			// it doesn't respond to the health endpoint until after the model is loaded
-			slog.Debug("loading model")
 		case ServerStatusReady:
-			slog.Debug(fmt.Sprintf("llama runner started in %f seconds", time.Since(start).Seconds()))
+			s.loadDuration = time.Since(start)
+			slog.Info(fmt.Sprintf("llama runner started in %0.2f seconds", s.loadDuration.Seconds()))
 			return nil
+		default:
+			lastStatus = status
+			time.Sleep(time.Millisecond * 250)
+			continue
 		}
 	}
 }
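
Note on the three WaitUntilRunning hunks: the 50 ms ticker is gone; the select becomes a non-blocking check (default:) for cancellation and child exit; the status probe no longer aborts the iteration on error; transitions are logged at info level only when the status changes; and a 250 ms sleep in the switch's default branch paces the polling. The dead ServerStatusLoadingModel case is dropped, and the ready path now records s.loadDuration. A stripped-down sketch of that loop shape (hypothetical probe helper, not the PR's code):

    package main

    import (
        "context"
        "fmt"
        "log/slog"
        "time"
    )

    // waitUntilReady mirrors the loop's new shape: a non-blocking select for
    // cancellation and child exit, then a paced, change-logged status poll.
    func waitUntilReady(ctx context.Context, done <-chan error, probe func(context.Context) string) error {
        deadline := time.Now().Add(10 * time.Minute)
        last := ""
        for {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case err := <-done:
                return fmt.Errorf("process terminated: %w", err)
            default:
                // Neither cancelled nor exited; fall through and poll.
            }
            if time.Now().After(deadline) {
                return fmt.Errorf("timed out waiting for server")
            }
            c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
            status := probe(c)
            cancel()
            if status != last && status != "ready" {
                slog.Info("waiting for server", "status", status) // only on change
            }
            if status == "ready" {
                return nil
            }
            last = status
            time.Sleep(250 * time.Millisecond)
        }
    }

    func main() {
        done := make(chan error, 1)
        calls := 0
        probe := func(context.Context) string { // stand-in for the HTTP health check
            calls++
            if calls > 3 {
                return "ready"
            }
            return "loading model"
        }
        fmt.Println(waitUntilReady(context.Background(), done, probe))
    }

(The PR itself keeps a per-iteration defer cancel(); the sketch cancels inline so deferred calls do not accumulate over a long wait.)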
@@ -943,8 +946,11 @@ func (s *llmServer) Close() error {
 		if err := s.cmd.Process.Kill(); err != nil {
 			return err
 		}
-
-		_ = s.cmd.Wait()
+		// if ProcessState is already populated, Wait already completed, no need to wait again
+		if s.cmd.ProcessState == nil {
+			slog.Debug("waiting for llama server to exit")
+			<-s.done
+		}
 
 		slog.Debug("llama server stopped")
 	}
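
Note: Close used to call s.cmd.Wait() directly, which would now race with the reaper goroutine, and exec.Cmd.Wait must not be called twice. Instead it kills the process and, if cmd.ProcessState is still nil (ProcessState is populated only once Wait has returned), blocks on s.done until the reaper has observed the exit. The guard in isolation (illustrative names):

    package main

    import (
        "fmt"
        "os/exec"
    )

    // shutdown kills a child whose exit status is owned by a reaper
    // goroutine (the goroutine making the one allowed cmd.Wait call).
    func shutdown(cmd *exec.Cmd, done <-chan error) error {
        if err := cmd.Process.Kill(); err != nil {
            return err
        }
        // ProcessState is set only after Wait returns; if it is still nil
        // the reaper has not finished, so wait for it here instead of
        // calling cmd.Wait a second time.
        if cmd.ProcessState == nil {
            <-done
        }
        return nil
    }

    func main() {
        cmd := exec.Command("sleep", "60")
        if err := cmd.Start(); err != nil {
            panic(err)
        }
        done := make(chan error, 1)
        go func() { done <- cmd.Wait() }()
        fmt.Println("shutdown:", shutdown(cmd, done))
    }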