Wire up load progress

This doesn't expose a UX yet, but wires the initial server portion
of progress reporting during load
This commit is contained in:
Daniel Hiltgen 2024-05-20 16:41:43 -07:00
parent 38255d2af1
commit b37b496a12
3 changed files with 61 additions and 8 deletions

View file

@ -334,6 +334,7 @@ struct server_metrics {
struct llama_server_context struct llama_server_context
{ {
llama_model *model = nullptr; llama_model *model = nullptr;
float modelProgress = 0.0;
llama_context *ctx = nullptr; llama_context *ctx = nullptr;
clip_ctx *clp_ctx = nullptr; clip_ctx *clp_ctx = nullptr;
@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) {
shutdown_handler(signal); shutdown_handler(signal);
} }
static bool update_load_progress(float progress, void *data)
{
((llama_server_context*)data)->modelProgress = progress;
return true;
}
#if defined(_WIN32) #if defined(_WIN32)
char* wchar_to_char(const wchar_t* wstr) { char* wchar_to_char(const wchar_t* wstr) {
if (wstr == nullptr) return nullptr; if (wstr == nullptr) return nullptr;
@ -2884,7 +2891,9 @@ int main(int argc, char **argv) {
break; break;
} }
case SERVER_STATE_LOADING_MODEL: case SERVER_STATE_LOADING_MODEL:
res.set_content(R"({"status": "loading model"})", "application/json"); char buf[128];
snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress);
res.set_content(buf, "application/json");
res.status = 503; // HTTP Service Unavailable res.status = 503; // HTTP Service Unavailable
break; break;
case SERVER_STATE_ERROR: case SERVER_STATE_ERROR:
@ -3079,6 +3088,9 @@ int main(int argc, char **argv) {
}); });
// load the model // load the model
params.progress_callback = update_load_progress;
params.progress_callback_user_data = (void*)&llama;
if (!llama.load_model(params)) if (!llama.load_model(params))
{ {
state.store(SERVER_STATE_ERROR); state.store(SERVER_STATE_ERROR);

View file

@ -0,0 +1,31 @@
diff --git a/common/common.cpp b/common/common.cpp
index ba1ecf0e..cead57cc 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
+ mparams.progress_callback = params.progress_callback;
+ mparams.progress_callback_user_data = params.progress_callback_user_data;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
diff --git a/common/common.h b/common/common.h
index d80344f2..71e84834 100644
--- a/common/common.h
+++ b/common/common.h
@@ -174,6 +174,13 @@ struct gpt_params {
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
std::vector<std::string> image; // path to image file(s)
+
+ // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
+ // If the provided progress_callback returns true, model loading continues.
+ // If it returns false, model loading is immediately aborted.
+ llama_progress_callback progress_callback = NULL;
+ // context pointer passed to the progress callback
+ void * progress_callback_user_data;
};
void gpt_params_handle_model_default(gpt_params & params);

View file

@ -55,6 +55,7 @@ type llmServer struct {
totalLayers uint64 totalLayers uint64
gpuCount int gpuCount int
loadDuration time.Duration // Record how long it took the model to load loadDuration time.Duration // Record how long it took the model to load
loadProgress float32
sem *semaphore.Weighted sem *semaphore.Weighted
} }
@ -425,10 +426,11 @@ func (s ServerStatus) ToString() string {
} }
type ServerStatusResp struct { type ServerStatusResp struct {
Status string `json:"status"` Status string `json:"status"`
SlotsIdle int `json:"slots_idle"` SlotsIdle int `json:"slots_idle"`
SlotsProcessing int `json:"slots_processing"` SlotsProcessing int `json:"slots_processing"`
Error string `json:"error"` Error string `json:"error"`
Progress float32 `json:"progress"`
} }
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
case "no slot available": case "no slot available":
return ServerStatusNoSlotsAvailable, nil return ServerStatusNoSlotsAvailable, nil
case "loading model": case "loading model":
s.loadProgress = status.Progress
return ServerStatusLoadingModel, nil return ServerStatusLoadingModel, nil
default: default:
return ServerStatusError, fmt.Errorf("server error: %+v", status) return ServerStatusError, fmt.Errorf("server error: %+v", status)
@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error {
func (s *llmServer) WaitUntilRunning(ctx context.Context) error { func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now() start := time.Now()
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load stallDuration := 60 * time.Second
stallTimer := time.Now().Add(stallDuration) // give up if we stall for
slog.Info("waiting for llama runner to start responding") slog.Info("waiting for llama runner to start responding")
var lastStatus ServerStatus = -1 var lastStatus ServerStatus = -1
@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg) return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
default: default:
} }
if time.Now().After(expiresAt) { if time.Now().After(stallTimer) {
// timeout // timeout
msg := "" msg := ""
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg msg = s.status.LastErrMsg
} }
return fmt.Errorf("timed out waiting for llama runner to start: %s", msg) return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
} }
if s.cmd.ProcessState != nil { if s.cmd.ProcessState != nil {
msg := "" msg := ""
@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
} }
ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel() defer cancel()
priorProgress := s.loadProgress
status, _ := s.getServerStatus(ctx) status, _ := s.getServerStatus(ctx)
if lastStatus != status && status != ServerStatusReady { if lastStatus != status && status != ServerStatusReady {
// Only log on status changes // Only log on status changes
@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
return nil return nil
default: default:
lastStatus = status lastStatus = status
// Reset the timer as long as we're making forward progress on the load
if priorProgress != s.loadProgress {
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
stallTimer = time.Now().Add(stallDuration)
}
time.Sleep(time.Millisecond * 250) time.Sleep(time.Millisecond * 250)
continue continue
} }