Wire up load progress
This doesn't expose any UX yet, but it wires up the initial server portion of progress reporting during model load.
parent 38255d2af1
commit b37b496a12
3 changed files with 61 additions and 8 deletions
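End to end, the flow is: the C++ runner reports a load progress value through its health endpoint, and the Go side polls that endpoint while waiting for the model. Below is a minimal Go sketch of a client consuming the new field; it is not part of this commit, and the endpoint path and address are assumptions for illustration only.

// Sketch only: poll an assumed runner health endpoint during load and print
// the "progress" field added by this commit.
package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "time"
)

type healthResp struct {
    Status   string  `json:"status"`
    Progress float32 `json:"progress"`
}

func main() {
    for {
        resp, err := http.Get("http://127.0.0.1:8080/health") // assumed address and path
        if err != nil {
            time.Sleep(250 * time.Millisecond)
            continue
        }
        var hr healthResp
        _ = json.NewDecoder(resp.Body).Decode(&hr)
        resp.Body.Close()
        if hr.Status != "loading model" {
            fmt.Println("status:", hr.Status)
            return
        }
        fmt.Printf("loading: %.0f%%\n", hr.Progress*100)
        time.Sleep(250 * time.Millisecond)
    }
}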
llm/ext_server/server.cpp (vendored, 14 changes)
@@ -334,6 +334,7 @@ struct server_metrics {
 struct llama_server_context
 {
     llama_model *model = nullptr;
+    float modelProgress = 0.0;
     llama_context *ctx = nullptr;
 
     clip_ctx *clp_ctx = nullptr;
@@ -2779,6 +2780,12 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }
 
+static bool update_load_progress(float progress, void *data)
+{
+    ((llama_server_context*)data)->modelProgress = progress;
+    return true;
+}
+
 #if defined(_WIN32)
 char* wchar_to_char(const wchar_t* wstr) {
     if (wstr == nullptr) return nullptr;
@@ -2884,7 +2891,9 @@ int main(int argc, char **argv) {
                 break;
             }
             case SERVER_STATE_LOADING_MODEL:
-                res.set_content(R"({"status": "loading model"})", "application/json");
+                char buf[128];
+                snprintf(&buf[0], 128, R"({"status": "loading model", "progress": %0.2f})", llama.modelProgress);
+                res.set_content(buf, "application/json");
                 res.status = 503; // HTTP Service Unavailable
                 break;
             case SERVER_STATE_ERROR:
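With this change, the health handler answers requests during load with a payload along these lines (the progress value shown is illustrative):

{"status": "loading model", "progress": 0.42}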
@@ -3079,6 +3088,9 @@ int main(int argc, char **argv) {
     });
 
     // load the model
+    params.progress_callback = update_load_progress;
+    params.progress_callback_user_data = (void*)&llama;
+
     if (!llama.load_model(params))
     {
         state.store(SERVER_STATE_ERROR);
llm/patches/01-load-progress.diff (new file, 31 lines)
diff --git a/common/common.cpp b/common/common.cpp
index ba1ecf0e..cead57cc 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1836,6 +1836,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
+    mparams.progress_callback = params.progress_callback;
+    mparams.progress_callback_user_data = params.progress_callback_user_data;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
diff --git a/common/common.h b/common/common.h
index d80344f2..71e84834 100644
--- a/common/common.h
+++ b/common/common.h
@@ -174,6 +174,13 @@ struct gpt_params {
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
     std::vector<std::string> image; // path to image file(s)
+
+    // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
+    // If the provided progress_callback returns true, model loading continues.
+    // If it returns false, model loading is immediately aborted.
+    llama_progress_callback progress_callback = NULL;
+    // context pointer passed to the progress callback
+    void * progress_callback_user_data;
 };
 
 void gpt_params_handle_model_default(gpt_params & params);
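The comments in the patch describe the callback contract exposed through gpt_params. As a rough Go analogue of that contract, purely illustrative and not the llama.cpp API: the loader reports progress in [0.0, 1.0] after each step and aborts as soon as the callback returns false.

// Illustrative sketch of the progress-callback contract in Go (hypothetical
// names, not part of this commit or of llama.cpp).
package main

import "fmt"

type progressCallback func(progress float32, data any) bool

// loadWithProgress simulates a loader that invokes the callback once per step
// and aborts early if the callback returns false.
func loadWithProgress(steps int, cb progressCallback, data any) error {
    for i := 1; i <= steps; i++ {
        p := float32(i) / float32(steps)
        if cb != nil && !cb(p, data) {
            return fmt.Errorf("load aborted at %0.2f", p)
        }
    }
    return nil
}

func main() {
    err := loadWithProgress(4, func(p float32, _ any) bool {
        fmt.Printf("progress %0.2f\n", p)
        return true // return false here to abort the load
    }, nil)
    fmt.Println("err:", err)
}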
llm/server.go
@@ -55,6 +55,7 @@ type llmServer struct {
     totalLayers  uint64
     gpuCount     int
     loadDuration time.Duration // Record how long it took the model to load
+    loadProgress float32
 
     sem *semaphore.Weighted
 }
@@ -425,10 +426,11 @@ func (s ServerStatus) ToString() string {
 }
 
 type ServerStatusResp struct {
     Status          string `json:"status"`
     SlotsIdle       int    `json:"slots_idle"`
     SlotsProcessing int    `json:"slots_processing"`
     Error           string `json:"error"`
+    Progress        float32 `json:"progress"`
 }
 
 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -476,6 +478,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
     case "no slot available":
         return ServerStatusNoSlotsAvailable, nil
     case "loading model":
+        s.loadProgress = status.Progress
         return ServerStatusLoadingModel, nil
     default:
         return ServerStatusError, fmt.Errorf("server error: %+v", status)
@@ -516,7 +519,8 @@ func (s *llmServer) Ping(ctx context.Context) error {
 
 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
     start := time.Now()
-    expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
+    stallDuration := 60 * time.Second
+    stallTimer := time.Now().Add(stallDuration) // give up if we stall for
 
     slog.Info("waiting for llama runner to start responding")
     var lastStatus ServerStatus = -1
@@ -534,13 +538,13 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
             return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
         default:
         }
-        if time.Now().After(expiresAt) {
+        if time.Now().After(stallTimer) {
             // timeout
             msg := ""
             if s.status != nil && s.status.LastErrMsg != "" {
                 msg = s.status.LastErrMsg
             }
-            return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
+            return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
         }
         if s.cmd.ProcessState != nil {
             msg := ""
@@ -551,6 +555,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
         }
         ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
         defer cancel()
+        priorProgress := s.loadProgress
         status, _ := s.getServerStatus(ctx)
         if lastStatus != status && status != ServerStatusReady {
             // Only log on status changes
@@ -563,6 +568,11 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
             return nil
         default:
             lastStatus = status
+            // Reset the timer as long as we're making forward progress on the load
+            if priorProgress != s.loadProgress {
+                slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
+                stallTimer = time.Now().Add(stallDuration)
+            }
             time.Sleep(time.Millisecond * 250)
             continue
         }
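For clarity, here is the stall-detection pattern used in WaitUntilRunning in isolation, as a simplified sketch rather than the actual ollama code: the deadline slides forward whenever the reported progress advances, so a slow but steady load never times out while a hung load still fails within the stall window.

// Simplified sketch of the sliding stall deadline (hypothetical helper names).
package main

import (
    "errors"
    "fmt"
    "time"
)

func waitForLoad(progress func() float32, stall time.Duration) error {
    deadline := time.Now().Add(stall)
    prior := float32(-1)
    for {
        if time.Now().After(deadline) {
            return errors.New("timed out: no load progress")
        }
        p := progress()
        if p >= 1.0 {
            return nil
        }
        if p != prior {
            // forward progress resets the deadline
            prior = p
            deadline = time.Now().Add(stall)
        }
        time.Sleep(250 * time.Millisecond)
    }
}

func main() {
    start := time.Now()
    err := waitForLoad(func() float32 {
        // simulated progress that completes after about one second
        return float32(time.Since(start)) / float32(time.Second)
    }, 500*time.Millisecond)
    fmt.Println("err:", err)
}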