runner.go: Better handle return NULL values from llama.cpp
Llama.cpp sometimes returns NULL as a return value to report an error. We should explicitly check for this and convert it to a Go error rather than putting NULL in our data structures and waiting for it to blow up later.
This commit is contained in:
parent
084929c293
commit
de1557a0dc
3 changed files with 37 additions and 13 deletions
|
@ -136,10 +136,6 @@ func (c *Context) Model() *Model {
|
|||
return &Model{c: C.llama_get_model(c.c)}
|
||||
}
|
||||
|
||||
func (c *Context) GetLogitsIth(i int) []float32 {
|
||||
return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
|
||||
}
|
||||
|
||||
func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
|
||||
C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
|
||||
}
|
||||
|
@ -163,7 +159,12 @@ func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
|||
}
|
||||
|
||||
func (c *Context) GetEmbeddingsIth(i int) []float32 {
|
||||
return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))), c.Model().NEmbd())
|
||||
embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
|
||||
if embeddings == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
|
||||
}
|
||||
|
||||
type ModelParams struct {
|
||||
|
@ -184,7 +185,7 @@ func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func LoadModelFromFile(modelPath string, params ModelParams) *Model {
|
||||
func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
||||
cparams := C.llama_model_default_params()
|
||||
cparams.n_gpu_layers = C.int(params.NumGpuLayers)
|
||||
cparams.main_gpu = C.int32_t(params.MainGpu)
|
||||
|
@ -214,18 +215,28 @@ func LoadModelFromFile(modelPath string, params ModelParams) *Model {
|
|||
cparams.progress_callback_user_data = unsafe.Pointer(&handle)
|
||||
}
|
||||
|
||||
return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
|
||||
m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
|
||||
if m.c == (*C.struct_llama_model)(C.NULL) {
|
||||
return nil, fmt.Errorf("unable to load model: %s", modelPath)
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func FreeModel(model *Model) {
|
||||
C.llama_free_model(model.c)
|
||||
}
|
||||
|
||||
func NewContextWithModel(model *Model, params ContextParams) *Context {
|
||||
return &Context{
|
||||
func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
|
||||
c := Context{
|
||||
c: C.llama_new_context_with_model(model.c, params.c),
|
||||
numThreads: int(params.c.n_threads),
|
||||
}
|
||||
if c.c == (*C.struct_llama_context)(C.NULL) {
|
||||
return nil, errors.New("unable to create llama context")
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (m *Model) NumVocab() int {
|
||||
|
|
|
@ -790,10 +790,17 @@ func (s *Server) loadModel(
|
|||
) {
|
||||
llama.BackendInit()
|
||||
|
||||
s.model = llama.LoadModelFromFile(mpath, params)
|
||||
var err error
|
||||
s.model, err = llama.LoadModelFromFile(mpath, params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
|
||||
s.lc = llama.NewContextWithModel(s.model, ctxParams)
|
||||
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if lpath != "" {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
|
||||
|
|
|
@ -958,7 +958,10 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
|
|||
if resp.StatusCode == http.StatusNotFound {
|
||||
if s.model == nil {
|
||||
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||
m := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.model = m
|
||||
}
|
||||
return s.model.Tokenize(content, false, true)
|
||||
|
@ -1027,7 +1030,10 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
|
|||
if resp.StatusCode == http.StatusNotFound {
|
||||
if s.model == nil {
|
||||
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||
m := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.model = m
|
||||
}
|
||||
var resp string
|
||||
|
|
Loading…
Add table
Reference in a new issue