From 984c9c628cc990183e45b27dddf2d38537264ad3 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Wed, 9 Aug 2023 16:13:24 -0400 Subject: [PATCH] fix embeddings invalid values --- llama/llama.go | 22 +++++++--------------- server/images.go | 26 ++------------------------ 2 files changed, 9 insertions(+), 39 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index f8c897d4..aba6c513 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -94,7 +94,6 @@ import ( "io" "log" "os" - "reflect" "strings" "sync" "unicode/utf8" @@ -421,27 +420,20 @@ func (llm *LLM) Embedding(input string) ([]float64, error) { return nil, errors.New("llama: tokenize embedding") } - retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)) + retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread)) if retval != 0 { return nil, errors.New("llama: eval") } - n := int(C.llama_n_embd(llm.ctx)) + n := C.llama_n_embd(llm.ctx) if n <= 0 { return nil, errors.New("llama: no embeddings generated") } + cEmbeddings := unsafe.Slice(C.llama_get_embeddings(llm.ctx), n) - embedPtr := C.llama_get_embeddings(llm.ctx) - if embedPtr == nil { - return nil, errors.New("llama: embedding retrieval failed") + embeddings := make([]float64, len(cEmbeddings)) + for i, v := range cEmbeddings { + embeddings[i] = float64(v) } - - header := reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(embedPtr)), - Len: n, - Cap: n, - } - embedSlice := *(*[]float64)(unsafe.Pointer(&header)) - - return embedSlice, nil + return embeddings, nil } diff --git a/server/images.go b/server/images.go index 22b0a74e..2ec24854 100644 --- a/server/images.go +++ b/server/images.go @@ -11,7 +11,6 @@ import ( "html/template" "io" "log" - "math" "net/http" "os" "path/filepath" @@ -480,31 +479,10 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { Total: len(data) - 1, Completed: i, }) - retry := 0 - generate: - if retry > 3 { - log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) - continue - } embed, err := llm.Embedding(d) if err != nil { - log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err) - retry++ - goto generate - } - // Check for NaN and Inf in the embedding, which can't be stored - for _, value := range embed { - if math.IsNaN(value) || math.IsInf(value, 0) { - log.Printf("reloading model, embedding contains NaN or Inf") - // reload the model to get a new embedding, the seed can effect these outputs and reloading changes it - llm.Close() - llm, err = llama.New(e.model, e.opts) - if err != nil { - return nil, fmt.Errorf("load model to generate embeddings: %v", err) - } - retry++ - goto generate - } + log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) + continue } embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed}) }