Shutdown faster

Make sure that when a shutdown signal arrives, we shut down quickly instead
of waiting for a potentially long exchange to wrap up.
This commit is contained in:
Daniel Hiltgen 2024-02-08 22:22:50 -08:00
parent 42b797ed9c
commit 6680761596
2 changed files with 44 additions and 1 deletions

View file

@ -258,7 +258,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
}) })
} }
if p.Stop { if p.Stop || bool(result.stop) {
fn(PredictResult{ fn(PredictResult{
Done: true, Done: true,
PromptEvalCount: p.Timings.PromptN, PromptEvalCount: p.Timings.PromptN,

View file

@ -1,4 +1,5 @@
#include "ext_server.h" #include "ext_server.h"
#include <atomic>
// Necessary evil since the server types are not defined in a header // Necessary evil since the server types are not defined in a header
#include "server.cpp" #include "server.cpp"
@ -27,8 +28,24 @@
// Expose the llama server as a callable extern "C" API // Expose the llama server as a callable extern "C" API
llama_server_context *llama = NULL; llama_server_context *llama = NULL;
std::thread ext_server_thread; std::thread ext_server_thread;
// Set to true once llama_server_stop() begins shutting the server down.
// The request entry points check this flag while llama_server_stop() may be
// running concurrently, so it is atomic to keep that access free of data races.
std::atomic<bool> shutting_down{false};
// Number of queue_results.recv() calls currently in flight; maintained through
// atomicRecv guards and polled by llama_server_stop() before it terminates the queue.
std::atomic_int recv_counter{0};
// RAII wrapper for tracking in-flight recv calls.
// Constructing a guard increments the referenced counter; destroying it
// decrements the counter again, including when the scope exits via an exception.
// The guard is non-copyable: a copy would decrement the counter a second time
// without a matching increment and leave the count permanently too low.
class atomicRecv {
  public:
    // Increments `atomic`. The counter must outlive this guard.
    explicit atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
        ++this->atomic;
    }
    // Decrements the counter that was incremented at construction.
    ~atomicRecv() {
        --this->atomic;
    }
    atomicRecv(const atomicRecv &) = delete;
    atomicRecv &operator=(const atomicRecv &) = delete;
  private:
    // Counter owned by the caller; only referenced here.
    std::atomic<int> &atomic;
};
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) { void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
recv_counter = 0;
assert(err != NULL && sparams != NULL); assert(err != NULL && sparams != NULL);
log_set_target(stderr); log_set_target(stderr);
if (!sparams->verbose_logging) { if (!sparams->verbose_logging) {
@ -151,7 +168,14 @@ void llama_server_start() {
void llama_server_stop() { void llama_server_stop() {
assert(llama != NULL); assert(llama != NULL);
// Shut down any in-flight requests and block incoming requests.
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n"); LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
shutting_down = true;
while (recv_counter.load() > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(50));
}
// This may take a while for any pending tasks to drain // This may take a while for any pending tasks to drain
// TODO - consider a timeout to cancel tasks if it's taking too long // TODO - consider a timeout to cancel tasks if it's taking too long
llama->queue_tasks.terminate(); llama->queue_tasks.terminate();
@ -166,6 +190,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
resp->id = -1; resp->id = -1;
resp->msg[0] = '\0'; resp->msg[0] = '\0';
try { try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
json data = json::parse(json_req); json data = json::parse(json_req);
resp->id = llama->queue_tasks.get_new_id(); resp->id = llama->queue_tasks.get_new_id();
llama->queue_results.add_waiting_task_id(resp->id); llama->queue_results.add_waiting_task_id(resp->id);
@ -187,6 +214,7 @@ void llama_server_completion_next_result(const int task_id,
resp->json_resp = NULL; resp->json_resp = NULL;
std::string result_json; std::string result_json;
try { try {
atomicRecv ar(recv_counter);
task_result result = llama->queue_results.recv(task_id); task_result result = llama->queue_results.recv(task_id);
result_json = result_json =
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace); result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
@ -203,6 +231,11 @@ void llama_server_completion_next_result(const int task_id,
llama->request_cancel(task_id); llama->request_cancel(task_id);
LOG_TEE("next result removing waiting task ID: %d\n", task_id); LOG_TEE("next result removing waiting task ID: %d\n", task_id);
llama->queue_results.remove_waiting_task_id(task_id); llama->queue_results.remove_waiting_task_id(task_id);
} else if (shutting_down) {
LOG_TEE("aborting completion due to shutdown %d\n", task_id);
llama->request_cancel(task_id);
llama->queue_results.remove_waiting_task_id(task_id);
resp->stop = true;
} }
} catch (std::exception &e) { } catch (std::exception &e) {
resp->error = true; resp->error = true;
@ -251,6 +284,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
err->id = 0; err->id = 0;
err->msg[0] = '\0'; err->msg[0] = '\0';
try { try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
const json body = json::parse(json_req); const json body = json::parse(json_req);
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (body.count("content") != 0) { if (body.count("content") != 0) {
@ -284,6 +320,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
err->id = 0; err->id = 0;
err->msg[0] = '\0'; err->msg[0] = '\0';
try { try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
const json body = json::parse(json_req); const json body = json::parse(json_req);
std::string content; std::string content;
if (body.count("tokens") != 0) { if (body.count("tokens") != 0) {
@ -311,6 +350,9 @@ void llama_server_embedding(const char *json_req, char **json_resp,
err->id = 0; err->id = 0;
err->msg[0] = '\0'; err->msg[0] = '\0';
try { try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
const json body = json::parse(json_req); const json body = json::parse(json_req);
json prompt; json prompt;
if (body.count("content") != 0) { if (body.count("content") != 0) {
@ -321,6 +363,7 @@ void llama_server_embedding(const char *json_req, char **json_resp,
const int task_id = llama->queue_tasks.get_new_id(); const int task_id = llama->queue_tasks.get_new_id();
llama->queue_results.add_waiting_task_id(task_id); llama->queue_results.add_waiting_task_id(task_id);
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1); llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
atomicRecv ar(recv_counter);
task_result result = llama->queue_results.recv(task_id); task_result result = llama->queue_results.recv(task_id);
std::string result_json = result.result_json.dump(); std::string result_json = result.result_json.dump();
const std::string::size_type size = result_json.size() + 1; const std::string::size_type size = result_json.size() + 1;