Merge pull request #2422 from dhiltgen/better_kill
More robust shutdown
This commit is contained in:
commit
939c60473f
2 changed files with 44 additions and 1 deletions
|
@ -258,7 +258,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
|||
})
|
||||
}
|
||||
|
||||
if p.Stop {
|
||||
if p.Stop || bool(result.stop) {
|
||||
fn(PredictResult{
|
||||
Done: true,
|
||||
PromptEvalCount: p.Timings.PromptN,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "ext_server.h"
|
||||
#include <atomic>
|
||||
|
||||
// Necessary evil since the server types are not defined in a header
|
||||
#include "server.cpp"
|
||||
|
@ -27,8 +28,24 @@
|
|||
// Expose the llama server as a callable extern "C" API
|
||||
llama_server_context *llama = NULL;
|
||||
std::thread ext_server_thread;
|
||||
bool shutting_down = false;
|
||||
std::atomic_int recv_counter;
|
||||
|
||||
// RAII wrapper for tracking in-flight recv calls
|
||||
class atomicRecv {
|
||||
public:
|
||||
atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
|
||||
++this->atomic;
|
||||
}
|
||||
~atomicRecv() {
|
||||
--this->atomic;
|
||||
}
|
||||
private:
|
||||
std::atomic<int> &atomic;
|
||||
};
|
||||
|
||||
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
recv_counter = 0;
|
||||
assert(err != NULL && sparams != NULL);
|
||||
log_set_target(stderr);
|
||||
if (!sparams->verbose_logging) {
|
||||
|
@ -151,7 +168,14 @@ void llama_server_start() {
|
|||
|
||||
void llama_server_stop() {
|
||||
assert(llama != NULL);
|
||||
// Shutdown any in-flight requests and block incoming requests.
|
||||
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
|
||||
shutting_down = true;
|
||||
|
||||
while (recv_counter.load() > 0) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||
}
|
||||
|
||||
// This may take a while for any pending tasks to drain
|
||||
// TODO - consider a timeout to cancel tasks if it's taking too long
|
||||
llama->queue_tasks.terminate();
|
||||
|
@ -166,6 +190,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
|
|||
resp->id = -1;
|
||||
resp->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
json data = json::parse(json_req);
|
||||
resp->id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(resp->id);
|
||||
|
@ -187,6 +214,7 @@ void llama_server_completion_next_result(const int task_id,
|
|||
resp->json_resp = NULL;
|
||||
std::string result_json;
|
||||
try {
|
||||
atomicRecv ar(recv_counter);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
result_json =
|
||||
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
|
@ -203,6 +231,11 @@ void llama_server_completion_next_result(const int task_id,
|
|||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} else if (shutting_down) {
|
||||
LOG_TEE("aborting completion due to shutdown %d\n", task_id);
|
||||
llama->request_cancel(task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
resp->stop = true;
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
resp->error = true;
|
||||
|
@ -251,6 +284,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
|
|||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
const json body = json::parse(json_req);
|
||||
std::vector<llama_token> tokens;
|
||||
if (body.count("content") != 0) {
|
||||
|
@ -284,6 +320,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
|
|||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
const json body = json::parse(json_req);
|
||||
std::string content;
|
||||
if (body.count("tokens") != 0) {
|
||||
|
@ -311,6 +350,9 @@ void llama_server_embedding(const char *json_req, char **json_resp,
|
|||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
const json body = json::parse(json_req);
|
||||
json prompt;
|
||||
if (body.count("content") != 0) {
|
||||
|
@ -321,6 +363,7 @@ void llama_server_embedding(const char *json_req, char **json_resp,
|
|||
const int task_id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(task_id);
|
||||
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
atomicRecv ar(recv_counter);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
std::string result_json = result.result_json.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
|
|
Loading…
Reference in a new issue