Wire up load progress
This doesn't expose a UX yet, but wires the initial server portion of progress reporting during load
This commit is contained in:
parent
38255d2af1
commit
b37b496a12
3 changed files with 61 additions and 8 deletions
14
llm/ext_server/server.cpp
vendored
14
llm/ext_server/server.cpp
vendored
|
@ -334,6 +334,7 @@ struct server_metrics {
|
|||
struct llama_server_context
|
||||
{
|
||||
llama_model *model = nullptr;
|
||||
float modelProgress = 0.0;
|
||||
llama_context *ctx = nullptr;
|
||||
|
||||
clip_ctx *clp_ctx = nullptr;
|
||||
|
@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) {
|
|||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
static bool update_load_progress(float progress, void *data)
|
||||
{
|
||||
((llama_server_context*)data)->modelProgress = progress;
|
||||
return true;
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
char* wchar_to_char(const wchar_t* wstr) {
|
||||
if (wstr == nullptr) return nullptr;
|
||||
|
@ -2884,7 +2891,9 @@ int main(int argc, char **argv) {
|
|||
break;
|
||||
}
|
||||
case SERVER_STATE_LOADING_MODEL:
|
||||
res.set_content(R"({"status": "loading model"})", "application/json");
|
||||
char buf[128];
|
||||
snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress);
|
||||
res.set_content(buf, "application/json");
|
||||
res.status = 503; // HTTP Service Unavailable
|
||||
break;
|
||||
case SERVER_STATE_ERROR:
|
||||
|
@ -3079,6 +3088,9 @@ int main(int argc, char **argv) {
|
|||
});
|
||||
|
||||
// load the model
|
||||
params.progress_callback = update_load_progress;
|
||||
params.progress_callback_user_data = (void*)&llama;
|
||||
|
||||
if (!llama.load_model(params))
|
||||
{
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
|
|
31
llm/patches/01-load-progress.diff
Normal file
31
llm/patches/01-load-progress.diff
Normal file
|
@ -0,0 +1,31 @@
|
|||
diff --git a/common/common.cpp b/common/common.cpp
|
||||
index ba1ecf0e..cead57cc 100644
|
||||
--- a/common/common.cpp
|
||||
+++ b/common/common.cpp
|
||||
@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||
mparams.use_mmap = params.use_mmap;
|
||||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
+ mparams.progress_callback = params.progress_callback;
|
||||
+ mparams.progress_callback_user_data = params.progress_callback_user_data;
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
} else {
|
||||
diff --git a/common/common.h b/common/common.h
|
||||
index d80344f2..71e84834 100644
|
||||
--- a/common/common.h
|
||||
+++ b/common/common.h
|
||||
@@ -174,6 +174,13 @@ struct gpt_params {
|
||||
// multimodal models (see examples/llava)
|
||||
std::string mmproj = ""; // path to multimodal projector
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
+
|
||||
+ // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||
+ // If the provided progress_callback returns true, model loading continues.
|
||||
+ // If it returns false, model loading is immediately aborted.
|
||||
+ llama_progress_callback progress_callback = NULL;
|
||||
+ // context pointer passed to the progress callback
|
||||
+ void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
void gpt_params_handle_model_default(gpt_params & params);
|
|
@ -55,6 +55,7 @@ type llmServer struct {
|
|||
totalLayers uint64
|
||||
gpuCount int
|
||||
loadDuration time.Duration // Record how long it took the model to load
|
||||
loadProgress float32
|
||||
|
||||
sem *semaphore.Weighted
|
||||
}
|
||||
|
@ -425,10 +426,11 @@ func (s ServerStatus) ToString() string {
|
|||
}
|
||||
|
||||
type ServerStatusResp struct {
|
||||
Status string `json:"status"`
|
||||
SlotsIdle int `json:"slots_idle"`
|
||||
SlotsProcessing int `json:"slots_processing"`
|
||||
Error string `json:"error"`
|
||||
Status string `json:"status"`
|
||||
SlotsIdle int `json:"slots_idle"`
|
||||
SlotsProcessing int `json:"slots_processing"`
|
||||
Error string `json:"error"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
|
@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
|||
case "no slot available":
|
||||
return ServerStatusNoSlotsAvailable, nil
|
||||
case "loading model":
|
||||
s.loadProgress = status.Progress
|
||||
return ServerStatusLoadingModel, nil
|
||||
default:
|
||||
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
||||
|
@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error {
|
|||
|
||||
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
||||
stallDuration := 60 * time.Second
|
||||
stallTimer := time.Now().Add(stallDuration) // give up if we stall for
|
||||
|
||||
slog.Info("waiting for llama runner to start responding")
|
||||
var lastStatus ServerStatus = -1
|
||||
|
@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||
default:
|
||||
}
|
||||
if time.Now().After(expiresAt) {
|
||||
if time.Now().After(stallTimer) {
|
||||
// timeout
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
|
||||
return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
|
||||
}
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
|
@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
|
||||
defer cancel()
|
||||
priorProgress := s.loadProgress
|
||||
status, _ := s.getServerStatus(ctx)
|
||||
if lastStatus != status && status != ServerStatusReady {
|
||||
// Only log on status changes
|
||||
|
@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||
return nil
|
||||
default:
|
||||
lastStatus = status
|
||||
// Reset the timer as long as we're making forward progress on the load
|
||||
if priorProgress != s.loadProgress {
|
||||
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
||||
stallTimer = time.Now().Add(stallDuration)
|
||||
}
|
||||
time.Sleep(time.Millisecond * 250)
|
||||
continue
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue