fix embeddings invalid values
This commit is contained in:
parent
9738ef85db
commit
984c9c628c
2 changed files with 9 additions and 39 deletions
|
@ -94,7 +94,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -421,27 +420,20 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
|
||||||
return nil, errors.New("llama: tokenize embedding")
|
return nil, errors.New("llama: tokenize embedding")
|
||||||
}
|
}
|
||||||
|
|
||||||
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
|
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
|
||||||
if retval != 0 {
|
if retval != 0 {
|
||||||
return nil, errors.New("llama: eval")
|
return nil, errors.New("llama: eval")
|
||||||
}
|
}
|
||||||
|
|
||||||
n := int(C.llama_n_embd(llm.ctx))
|
n := C.llama_n_embd(llm.ctx)
|
||||||
if n <= 0 {
|
if n <= 0 {
|
||||||
return nil, errors.New("llama: no embeddings generated")
|
return nil, errors.New("llama: no embeddings generated")
|
||||||
}
|
}
|
||||||
|
cEmbeddings := unsafe.Slice(C.llama_get_embeddings(llm.ctx), n)
|
||||||
|
|
||||||
embedPtr := C.llama_get_embeddings(llm.ctx)
|
embeddings := make([]float64, len(cEmbeddings))
|
||||||
if embedPtr == nil {
|
for i, v := range cEmbeddings {
|
||||||
return nil, errors.New("llama: embedding retrieval failed")
|
embeddings[i] = float64(v)
|
||||||
}
|
}
|
||||||
|
return embeddings, nil
|
||||||
header := reflect.SliceHeader{
|
|
||||||
Data: uintptr(unsafe.Pointer(embedPtr)),
|
|
||||||
Len: n,
|
|
||||||
Cap: n,
|
|
||||||
}
|
|
||||||
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
|
|
||||||
|
|
||||||
return embedSlice, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -480,31 +479,10 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
||||||
Total: len(data) - 1,
|
Total: len(data) - 1,
|
||||||
Completed: i,
|
Completed: i,
|
||||||
})
|
})
|
||||||
retry := 0
|
|
||||||
generate:
|
|
||||||
if retry > 3 {
|
|
||||||
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
embed, err := llm.Embedding(d)
|
embed, err := llm.Embedding(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
|
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
|
||||||
retry++
|
continue
|
||||||
goto generate
|
|
||||||
}
|
|
||||||
// Check for NaN and Inf in the embedding, which can't be stored
|
|
||||||
for _, value := range embed {
|
|
||||||
if math.IsNaN(value) || math.IsInf(value, 0) {
|
|
||||||
log.Printf("reloading model, embedding contains NaN or Inf")
|
|
||||||
// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
|
|
||||||
llm.Close()
|
|
||||||
llm, err = llama.New(e.model, e.opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
|
||||||
}
|
|
||||||
retry++
|
|
||||||
goto generate
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
|
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue