From b37b496a12ebad0105ed17826d838346bff6e5ef Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 20 May 2024 16:41:43 -0700 Subject: [PATCH] Wire up load progress This doesn't expose a UX yet, but wires the initial server portion of progress reporting during load --- llm/ext_server/server.cpp | 14 +++++++++++++- llm/patches/01-load-progress.diff | 31 +++++++++++++++++++++++++++++++ llm/server.go | 24 +++++++++++++++++------- 3 files changed, 61 insertions(+), 8 deletions(-) create mode 100644 llm/patches/01-load-progress.diff diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 3e03bb34..e342d5f1 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -334,6 +334,7 @@ struct server_metrics { struct llama_server_context { llama_model *model = nullptr; + float modelProgress = 0.0; llama_context *ctx = nullptr; clip_ctx *clp_ctx = nullptr; @@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) { shutdown_handler(signal); } +static bool update_load_progress(float progress, void *data) +{ + ((llama_server_context*)data)->modelProgress = progress; + return true; +} + #if defined(_WIN32) char* wchar_to_char(const wchar_t* wstr) { if (wstr == nullptr) return nullptr; @@ -2884,7 +2891,9 @@ int main(int argc, char **argv) { break; } case SERVER_STATE_LOADING_MODEL: - res.set_content(R"({"status": "loading model"})", "application/json"); + char buf[128]; + snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress); + res.set_content(buf, "application/json"); res.status = 503; // HTTP Service Unavailable break; case SERVER_STATE_ERROR: @@ -3079,6 +3088,9 @@ int main(int argc, char **argv) { }); // load the model + params.progress_callback = update_load_progress; + params.progress_callback_user_data = (void*)&llama; + if (!llama.load_model(params)) { state.store(SERVER_STATE_ERROR); diff --git a/llm/patches/01-load-progress.diff b/llm/patches/01-load-progress.diff new file mode 100644 index 00000000..acd44d20 --- /dev/null +++ b/llm/patches/01-load-progress.diff @@ -0,0 +1,31 @@ +diff --git a/common/common.cpp b/common/common.cpp +index ba1ecf0e..cead57cc 100644 +--- a/common/common.cpp ++++ b/common/common.cpp +@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & + mparams.use_mmap = params.use_mmap; + mparams.use_mlock = params.use_mlock; + mparams.check_tensors = params.check_tensors; ++ mparams.progress_callback = params.progress_callback; ++ mparams.progress_callback_user_data = params.progress_callback_user_data; + if (params.kv_overrides.empty()) { + mparams.kv_overrides = NULL; + } else { +diff --git a/common/common.h b/common/common.h +index d80344f2..71e84834 100644 +--- a/common/common.h ++++ b/common/common.h +@@ -174,6 +174,13 @@ struct gpt_params { + // multimodal models (see examples/llava) + std::string mmproj = ""; // path to multimodal projector + std::vector image; // path to image file(s) ++ ++ // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. ++ // If the provided progress_callback returns true, model loading continues. ++ // If it returns false, model loading is immediately aborted. ++ llama_progress_callback progress_callback = NULL; ++ // context pointer passed to the progress callback ++ void * progress_callback_user_data; + }; + + void gpt_params_handle_model_default(gpt_params & params); diff --git a/llm/server.go b/llm/server.go index c63a76a4..384d31ca 100644 --- a/llm/server.go +++ b/llm/server.go @@ -55,6 +55,7 @@ type llmServer struct { totalLayers uint64 gpuCount int loadDuration time.Duration // Record how long it took the model to load + loadProgress float32 sem *semaphore.Weighted } @@ -425,10 +426,11 @@ func (s ServerStatus) ToString() string { } type ServerStatusResp struct { - Status string `json:"status"` - SlotsIdle int `json:"slots_idle"` - SlotsProcessing int `json:"slots_processing"` - Error string `json:"error"` + Status string `json:"status"` + SlotsIdle int `json:"slots_idle"` + SlotsProcessing int `json:"slots_processing"` + Error string `json:"error"` + Progress float32 `json:"progress"` } func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { @@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { case "no slot available": return ServerStatusNoSlotsAvailable, nil case "loading model": + s.loadProgress = status.Progress return ServerStatusLoadingModel, nil default: return ServerStatusError, fmt.Errorf("server error: %+v", status) @@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error { func (s *llmServer) WaitUntilRunning(ctx context.Context) error { start := time.Now() - expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load + stallDuration := 60 * time.Second + stallTimer := time.Now().Add(stallDuration) // give up if we stall for slog.Info("waiting for llama runner to start responding") var lastStatus ServerStatus = -1 @@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { return fmt.Errorf("llama runner process has terminated: %v %s", err, msg) default: } - if time.Now().After(expiresAt) { + if time.Now().After(stallTimer) { // timeout msg := "" if s.status != nil && s.status.LastErrMsg != "" { msg = s.status.LastErrMsg } - return fmt.Errorf("timed out waiting for llama runner to start: %s", msg) + return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg) } if s.cmd.ProcessState != nil { msg := "" @@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { } ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond) defer cancel() + priorProgress := s.loadProgress status, _ := s.getServerStatus(ctx) if lastStatus != status && status != ServerStatusReady { // Only log on status changes @@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { return nil default: lastStatus = status + // Reset the timer as long as we're making forward progress on the load + if priorProgress != s.loadProgress { + slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress)) + stallTimer = time.Now().Add(stallDuration) + } time.Sleep(time.Millisecond * 250) continue }