Merge pull request #2422 from dhiltgen/better_kill
More robust shutdown
commit 939c60473f
2 changed files with 44 additions and 1 deletion
@@ -258,7 +258,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 			})
 		}
 
-		if p.Stop {
+		if p.Stop || bool(result.stop) {
 			fn(PredictResult{
 				Done:            true,
 				PromptEvalCount: p.Timings.PromptN,
@@ -1,4 +1,5 @@
 #include "ext_server.h"
+#include <atomic>
 
 // Necessary evil since the server types are not defined in a header
 #include "server.cpp"
@@ -27,8 +28,24 @@
 // Expose the llama server as a callable extern "C" API
 llama_server_context *llama = NULL;
 std::thread ext_server_thread;
+bool shutting_down = false;
+std::atomic_int recv_counter;
+
+// RAII wrapper for tracking in-flight recv calls
+class atomicRecv {
+ public:
+  atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
+    ++this->atomic;
+  }
+  ~atomicRecv() {
+    --this->atomic;
+  }
+ private:
+  std::atomic<int> &atomic;
+};
 
 void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
+  recv_counter = 0;
   assert(err != NULL && sparams != NULL);
   log_set_target(stderr);
   if (!sparams->verbose_logging) {
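The atomicRecv class added above is a small RAII counter: the constructor increments recv_counter and the destructor decrements it again, even if the guarded call throws, so the shutdown path can tell how many blocking recv() calls are still in flight. A minimal standalone sketch of the same pattern, assuming illustrative names (in_flight, InflightGuard, blocking_recv are not part of the patch):

#include <atomic>

std::atomic<int> in_flight{0};  // how many calls are currently blocked

// RAII guard: count on entry, uncount on every exit path.
class InflightGuard {
 public:
  explicit InflightGuard(std::atomic<int> &c) : counter(c) { ++counter; }
  ~InflightGuard() { --counter; }
 private:
  std::atomic<int> &counter;
};

void blocking_recv() {
  InflightGuard guard(in_flight);  // in_flight stays > 0 for the whole call
  // ... potentially blocking work; if it throws, ~InflightGuard still runs
}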
@@ -151,7 +168,14 @@ void llama_server_start() {
 
 void llama_server_stop() {
   assert(llama != NULL);
+  // Shutdown any in-flight requests and block incoming requests.
   LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
+  shutting_down = true;
+
+  while (recv_counter.load() > 0) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+  }
+
   // This may take a while for any pending tasks to drain
   // TODO - consider a timeout to cancel tasks if it's taking too long
   llama->queue_tasks.terminate();
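The TODO above leaves open what happens if an in-flight recv() never returns: the while loop would spin forever. One possible way to bound it, sketched here and not part of this patch, is to give the drain a deadline and fall through to queue termination once it expires (drain_with_timeout is an illustrative name; the 50 ms poll interval mirrors the loop above, the budget value is up to the caller):

#include <atomic>
#include <chrono>
#include <thread>

// Wait for the in-flight counter to reach zero, but give up after `budget`.
// Returns true if everything drained in time, false if the deadline was hit.
static bool drain_with_timeout(std::atomic<int> &counter,
                               std::chrono::milliseconds budget) {
  const auto deadline = std::chrono::steady_clock::now() + budget;
  while (counter.load() > 0) {
    if (std::chrono::steady_clock::now() >= deadline) {
      return false;  // caller can still terminate the queue, just less gracefully
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  }
  return true;
}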
@@ -166,6 +190,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
   resp->id = -1;
   resp->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     json data = json::parse(json_req);
     resp->id = llama->queue_tasks.get_new_id();
     llama->queue_results.add_waiting_task_id(resp->id);
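The same early check is repeated in the other extern "C" entry points later in this diff (tokenize, detokenize, embedding): once shutting_down is set, new requests fail fast before a task is ever queued, so the drain loop only has to wait for work that was already in flight. A minimal sketch of the shape of that guard, with an illustrative function name and the real request handling elided:

#include <stdexcept>

extern bool shutting_down;  // the flag introduced earlier in this patch

void guarded_entry(const char *json_req) {
  // Reject new work as soon as shutdown has started; in the real entry
  // points this throw is caught below and turned into an error response.
  if (shutting_down) {
    throw std::runtime_error("server shutting down");
  }
  // ... normal path: parse json_req, allocate a task id, enqueue the task ...
}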
@@ -187,6 +214,7 @@ void llama_server_completion_next_result(const int task_id,
   resp->json_resp = NULL;
   std::string result_json;
   try {
+    atomicRecv ar(recv_counter);
    task_result result = llama->queue_results.recv(task_id);
    result_json =
        result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
@@ -203,6 +231,11 @@ void llama_server_completion_next_result(const int task_id,
       llama->request_cancel(task_id);
       LOG_TEE("next result removing waiting task ID: %d\n", task_id);
       llama->queue_results.remove_waiting_task_id(task_id);
+    } else if (shutting_down) {
+      LOG_TEE("aborting completion due to shutdown %d\n", task_id);
+      llama->request_cancel(task_id);
+      llama->queue_results.remove_waiting_task_id(task_id);
+      resp->stop = true;
     }
   } catch (std::exception &e) {
     resp->error = true;
@@ -251,6 +284,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::vector<llama_token> tokens;
     if (body.count("content") != 0) {
@@ -284,6 +320,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     std::string content;
     if (body.count("tokens") != 0) {
@@ -311,6 +350,9 @@ void llama_server_embedding(const char *json_req, char **json_resp,
   err->id = 0;
   err->msg[0] = '\0';
   try {
+    if (shutting_down) {
+      throw std::runtime_error("server shutting down");
+    }
     const json body = json::parse(json_req);
     json prompt;
     if (body.count("content") != 0) {
@@ -321,6 +363,7 @@ void llama_server_embedding(const char *json_req, char **json_resp,
    const int task_id = llama->queue_tasks.get_new_id();
    llama->queue_results.add_waiting_task_id(task_id);
    llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
+    atomicRecv ar(recv_counter);
    task_result result = llama->queue_results.recv(task_id);
    std::string result_json = result.result_json.dump();
    const std::string::size_type size = result_json.size() + 1;