fix embeddings invalid values

This commit is contained in:
Bruce MacDonald 2023-08-10 10:17:00 -04:00 committed by GitHub
commit 8e1234b758
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 39 deletions

View file

@ -94,7 +94,6 @@ import (
"io" "io"
"log" "log"
"os" "os"
"reflect"
"strings" "strings"
"sync" "sync"
"unicode/utf8" "unicode/utf8"
@ -421,27 +420,20 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
return nil, errors.New("llama: tokenize embedding") return nil, errors.New("llama: tokenize embedding")
} }
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)) retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
if retval != 0 { if retval != 0 {
return nil, errors.New("llama: eval") return nil, errors.New("llama: eval")
} }
n := int(C.llama_n_embd(llm.ctx)) n := C.llama_n_embd(llm.ctx)
if n <= 0 { if n <= 0 {
return nil, errors.New("llama: no embeddings generated") return nil, errors.New("llama: no embeddings generated")
} }
cEmbeddings := unsafe.Slice(C.llama_get_embeddings(llm.ctx), n)
embedPtr := C.llama_get_embeddings(llm.ctx) embeddings := make([]float64, len(cEmbeddings))
if embedPtr == nil { for i, v := range cEmbeddings {
return nil, errors.New("llama: embedding retrieval failed") embeddings[i] = float64(v)
} }
return embeddings, nil
header := reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(embedPtr)),
Len: n,
Cap: n,
}
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
return embedSlice, nil
} }

View file

@ -11,7 +11,6 @@ import (
"html/template" "html/template"
"io" "io"
"log" "log"
"math"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -480,31 +479,10 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
Total: len(data) - 1, Total: len(data) - 1,
Completed: i, Completed: i,
}) })
retry := 0
generate:
if retry > 3 {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue
}
embed, err := llm.Embedding(d) embed, err := llm.Embedding(d)
if err != nil { if err != nil {
log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err) log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
retry++ continue
goto generate
}
// Check for NaN and Inf in the embedding, which can't be stored
for _, value := range embed {
if math.IsNaN(value) || math.IsInf(value, 0) {
log.Printf("reloading model, embedding contains NaN or Inf")
// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
llm.Close()
llm, err = llama.New(e.model, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
retry++
goto generate
}
} }
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed}) embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
} }