From 6680761596cbd832619ba5a295f03b74c6500743 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen
Date: Thu, 8 Feb 2024 22:22:50 -0800
Subject: [PATCH] Shutdown faster

Make sure that when a shutdown signal comes, we shut down quickly instead
of waiting for a potentially long exchange to wrap up.
---
 llm/dyn_ext_server.go         |  2 +-
 llm/ext_server/ext_server.cpp | 43 +++++++++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go
index f7e19a7b..45b8da12 100644
--- a/llm/dyn_ext_server.go
+++ b/llm/dyn_ext_server.go
@@ -258,7 +258,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 			})
 		}
 
-		if p.Stop {
+		if p.Stop || bool(result.stop) {
 			fn(PredictResult{
 				Done:            true,
 				PromptEvalCount: p.Timings.PromptN,
diff --git a/llm/ext_server/ext_server.cpp b/llm/ext_server/ext_server.cpp
index b59b46d2..daba4e65 100644
--- a/llm/ext_server/ext_server.cpp
+++ b/llm/ext_server/ext_server.cpp
@@ -1,4 +1,5 @@
 #include "ext_server.h"
+#include <atomic>
 
 // Necessary evil since the server types are not defined in a header
 #include "server.cpp"
@@ -27,8 +28,24 @@
 // Expose the llama server as a callable extern "C" API
 llama_server_context *llama = NULL;
 std::thread ext_server_thread;
+bool shutting_down = false;
+std::atomic_int recv_counter;
+
+// RAII wrapper for tracking in-flight recv calls
+class atomicRecv {
+  public:
+    atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
+      ++this->atomic;
+    }
+    ~atomicRecv() {
+      --this->atomic;
+    }
+  private:
+    std::atomic<int> &atomic;
+};
 
 void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
+  recv_counter = 0;
   assert(err != NULL && sparams != NULL);
   log_set_target(stderr);
   if (!sparams->verbose_logging) {
@@ -151,7 +168,14 @@ void llama_server_start() {
 
 void llama_server_stop() {
   assert(llama != NULL);
+  // Shutdown any in-flight requests and block incoming requests.
+  LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
+  shutting_down = true;
+
+  while (recv_counter.load() > 0) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+  }
 
   // This may take a while for any pending tasks to drain
   // TODO - consider a timeout to cancel tasks if it's taking too long
   llama->queue_tasks.terminate();
@@ -166,6 +190,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
   resp->id = -1;
   resp->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     json data = json::parse(json_req);
     resp->id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(resp->id);
@@ -187,6 +214,7 @@ void llama_server_completion_next_result(const int task_id,
   resp->json_resp = NULL;
   std::string result_json;
   try {
+    atomicRecv ar(recv_counter);
     task_result result = llama->queue_results.recv(task_id);
     result_json =
         result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
@@ -203,6 +231,11 @@ void llama_server_completion_next_result(const int task_id,
       llama->request_cancel(task_id);
       LOG_TEE("next result removing waiting task ID: %d\n", task_id);
       llama->queue_results.remove_waiting_task_id(task_id);
+    } else if (shutting_down) {
+      LOG_TEE("aborting completion due to shutdown %d\n", task_id);
+      llama->request_cancel(task_id);
+      llama->queue_results.remove_waiting_task_id(task_id);
+      resp->stop = true;
     }
   } catch (std::exception &e) {
     resp->error = true;
@@ -251,6 +284,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::vector<llama_token> tokens;
     if (body.count("content") != 0) {
@@ -284,6 +320,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::string content;
     if (body.count("tokens") != 0) {
@@ -311,6 +350,9 @@ void llama_server_embedding(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     json prompt;
     if (body.count("content") != 0) {
@@ -321,6 +363,7 @@ void llama_server_embedding(const char *json_req, char **json_resp,
     const int task_id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(task_id);
     llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
+    atomicRecv ar(recv_counter);
     task_result result = llama->queue_results.recv(task_id);
     std::string result_json = result.result_json.dump();
     const std::string::size_type size = result_json.size() + 1;