From de1557a0dcdcd1cfd5b5af00a17042c3dba97ffd Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 22 Oct 2024 14:57:46 -0700
Subject: [PATCH] runner.go: Better handle return NULL values from llama.cpp

Llama.cpp sometimes returns NULL as a return value to report an error.
We should explicitly check for this and convert it to a Go error rather
than putting NULL in our data structures and waiting for it to blow up
later.
---
 llama/llama.go         | 29 ++++++++++++++++++++---------
 llama/runner/runner.go | 11 +++++++++--
 llm/server.go          | 10 ++++++++--
 3 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index ca17e38c..f7c0f362 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -136,10 +136,6 @@ func (c *Context) Model() *Model {
 	return &Model{c: C.llama_get_model(c.c)}
 }
 
-func (c *Context) GetLogitsIth(i int) []float32 {
-	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
-}
-
 func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
 	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
 }
@@ -163,7 +159,12 @@ func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
 }
 
 func (c *Context) GetEmbeddingsIth(i int) []float32 {
-	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))), c.Model().NEmbd())
+	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
+	if embeddings == nil {
+		return nil
+	}
+
+	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 
 type ModelParams struct {
@@ -184,7 +185,7 @@ func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
 	return true
 }
 
-func LoadModelFromFile(modelPath string, params ModelParams) *Model {
+func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	cparams := C.llama_model_default_params()
 	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
 	cparams.main_gpu = C.int32_t(params.MainGpu)
@@ -214,18 +215,28 @@ func LoadModelFromFile(modelPath string, params ModelParams) *Model {
 		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
 	}
 
-	return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
+	m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
+	if m.c == (*C.struct_llama_model)(C.NULL) {
+		return nil, fmt.Errorf("unable to load model: %s", modelPath)
+	}
+
+	return &m, nil
 }
 
 func FreeModel(model *Model) {
 	C.llama_free_model(model.c)
 }
 
-func NewContextWithModel(model *Model, params ContextParams) *Context {
-	return &Context{
+func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
+	c := Context{
 		c:          C.llama_new_context_with_model(model.c, params.c),
 		numThreads: int(params.c.n_threads),
 	}
+	if c.c == (*C.struct_llama_context)(C.NULL) {
+		return nil, errors.New("unable to create llama context")
+	}
+
+	return &c, nil
 }
 
 func (m *Model) NumVocab() int {
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index b35704b5..f472d076 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -790,10 +790,17 @@ func (s *Server) loadModel(
 ) {
 	llama.BackendInit()
 
-	s.model = llama.LoadModelFromFile(mpath, params)
+	var err error
+	s.model, err = llama.LoadModelFromFile(mpath, params)
+	if err != nil {
+		panic(err)
+	}
 
 	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
-	s.lc = llama.NewContextWithModel(s.model, ctxParams)
+	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
+	if err != nil {
+		panic(err)
+	}
 
 	if lpath != "" {
 		err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
diff --git a/llm/server.go b/llm/server.go
index cc4eac90..a4c99dd9 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -958,7 +958,10 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 	if resp.StatusCode == http.StatusNotFound {
 		if s.model == nil {
 			slog.Debug("new runner detected, loading model for cgo tokenization")
-			m := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			if err != nil {
+				return nil, err
+			}
 			s.model = m
 		}
 		return s.model.Tokenize(content, false, true)
@@ -1027,7 +1030,10 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error)
 	if resp.StatusCode == http.StatusNotFound {
 		if s.model == nil {
 			slog.Debug("new runner detected, loading model for cgo tokenization")
-			m := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			if err != nil {
+				return "", err
+			}
 			s.model = m
 		}
 		var resp string
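
Reviewer note (not part of the patch): with this change, LoadModelFromFile and NewContextWithModel return an error alongside the pointer, so any caller outside the files touched here must be updated to check it. Below is a minimal caller-side sketch; the package name, helper name, and error wrapping are illustrative assumptions, and the import path is assumed to be github.com/ollama/ollama/llama.

	// Hypothetical example, not part of this patch: shows how a caller
	// consumes the new error-returning signatures instead of letting a
	// nil model or context blow up later.
	package example

	import (
		"fmt"

		"github.com/ollama/ollama/llama" // assumed import path
	)

	func newLlamaContext(mpath string, params llama.ModelParams, ctxParams llama.ContextParams) (*llama.Context, error) {
		model, err := llama.LoadModelFromFile(mpath, params)
		if err != nil {
			return nil, fmt.Errorf("loading model %s: %w", mpath, err)
		}

		lc, err := llama.NewContextWithModel(model, ctxParams)
		if err != nil {
			// Release the model if context creation fails.
			llama.FreeModel(model)
			return nil, fmt.Errorf("creating llama context: %w", err)
		}

		return lc, nil
	}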