pr comments

- default to embeddings enabled
- move embedding logic for loaded model to request
- allow embedding full directory
- close llm on reload
Bruce MacDonald 2023-08-08 13:49:37 -04:00
parent a6f6d18f83
commit 21ddcaa1f1
3 changed files with 97 additions and 82 deletions

@@ -275,6 +275,7 @@ func DefaultOptions() Options {
 		UseMLock:           false,
 		RopeFrequencyBase:  10000.0,
 		RopeFrequencyScale: 1.0,
+		EmbeddingOnly:      true,
 		RepeatLastN:        64,
 		RepeatPenalty:      1.1,
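EmbeddingOnly now defaults to true so that embeddings work without extra configuration: embedding support is a load-time option, so it must already be enabled by the time Embedding() is called. A minimal sketch of the resulting call pattern, reusing only names that appear in this diff (DefaultOptions, llama.New, llm.Embedding, llm.Close); the wrapper function itself is hypothetical, not actual ollama code:

	// Hypothetical helper: load a model and embed one prompt.
	func embedPrompt(modelPath, prompt string) ([]float64, error) {
		opts := DefaultOptions() // EmbeddingOnly is true by default after this commit
		llm, err := llama.New(modelPath, opts)
		if err != nil {
			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
		}
		// release the model when done, mirroring the llm.Close() fix below
		defer llm.Close()
		return llm.Embedding(prompt) // would fail if EmbeddingOnly were false
	}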

@@ -23,7 +23,6 @@ import (
 	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/vector"
-	"gonum.org/v1/gonum/mat"
 )

 type RegistryOptions struct {
@@ -42,7 +41,7 @@ type Model struct {
 	Embeddings []vector.Embedding
 }

-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -67,26 +66,12 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 	vars.System = m.System
 	vars.Prompt = request.Prompt
 	vars.Context = request.Context
+	vars.Embed = embedding

 	if request.System != "" {
 		vars.System = request.System
 	}

-	if len(m.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(request.Prompt)
-		if err != nil {
-			return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
-		}
-		// TODO: set embed_top from specified parameters in modelfile
-		embed_top := 3
-		embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-		toEmbed := ""
-		for _, e := range embed {
-			toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
-		}
-		vars.Embed = toEmbed
-	}
-
 	var sb strings.Builder
 	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
@@ -432,8 +417,19 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 		return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 	}

-	for _, filePath := range e.files {
-		// TODO: check if txt file type
+	addedFiles := make(map[string]bool) // keep track of files that have already been added
+	for _, filePattern := range e.files {
+		matchingFiles, err := filepath.Glob(filePattern)
+		if err != nil {
+			return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
+		}
+
+		for _, filePath := range matchingFiles {
+			if addedFiles[filePath] {
+				continue
+			}
+			addedFiles[filePath] = true
+			// TODO: check file type
 			f, err := os.Open(filePath)
 			if err != nil {
 				return nil, fmt.Errorf("could not open embed file: %w", err)
@@ -477,7 +473,8 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 			for _, value := range embed {
 				if math.IsNaN(value) || math.IsInf(value, 0) {
 					log.Printf("reloading model, embedding contains NaN or Inf")
-					// reload the model to get a new embedding
+					// reload the model to get a new embedding; the seed can affect these outputs, and reloading changes it
+					llm.Close()
 					llm, err = llama.New(model.ModelPath, e.opts)
 					if err != nil {
 						return nil, fmt.Errorf("load model to generate embeddings: %v", err)
@@ -497,7 +494,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 			digest, size := GetSHA256Digest(r)

 			// Reset the position of the reader after calculating the digest
-			if _, err := r.Seek(0, 0); err != nil {
+			if _, err := r.Seek(0, io.SeekStart); err != nil {
 				return nil, fmt.Errorf("could not reset embed reader: %w", err)
 			}
@@ -513,6 +510,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 			layers = append(layers, layer)
 		}
 	}
+	}

 	return layers, nil
 }
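The reworked loop above is what allows embedding a full directory: each entry in e.files is now treated as a glob pattern, and the addedFiles map keeps a file that matches several patterns from being embedded twice. A self-contained sketch of the same glob-and-dedup pattern (the file patterns here are hypothetical examples, not values from the commit):

	package main

	import (
		"fmt"
		"path/filepath"
	)

	func main() {
		// "docs/*" is how a whole directory would be embedded;
		// "docs/readme.txt" overlaps it to show the dedup at work.
		patterns := []string{"docs/*", "docs/readme.txt"}

		added := make(map[string]bool) // tracks files already seen
		for _, pattern := range patterns {
			matches, err := filepath.Glob(pattern)
			if err != nil {
				panic(err) // Glob only errors on a malformed pattern
			}
			for _, path := range matches {
				if added[path] {
					continue // readme.txt matched both patterns; skip the repeat
				}
				added[path] = true
				fmt.Println("would embed:", path)
			}
		}
	}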

@@ -17,6 +17,7 @@ import (
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"gonum.org/v1/gonum/mat"

 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
@@ -114,7 +115,22 @@ func GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()

-	prompt, err := model.Prompt(req)
+	embedding := ""
+	if model.Embeddings != nil && len(model.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		// TODO: set embed_top from specified parameters in modelfile
+		embed_top := 3
+		topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		for _, e := range topK {
+			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
+		}
+	}
+
+	prompt, err := model.Prompt(req, embedding)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
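vector.TopK, now called on the request path, selects the k stored embeddings nearest to the prompt embedding so their text can be spliced into the template via vars.Embed. A hedged sketch of what such a top-k search does, using cosine similarity as the distance measure (an assumption made for illustration; it is not necessarily the metric the ollama vector package uses, and none of these names come from the commit):

	package main

	import (
		"fmt"
		"math"
		"sort"
	)

	// scored pairs a stored chunk of text with its similarity to the query.
	type scored struct {
		text  string
		score float64
	}

	// cosine returns the cosine similarity of two equal-length vectors.
	func cosine(a, b []float64) float64 {
		var dot, na, nb float64
		for i := range a {
			dot += a[i] * b[i]
			na += a[i] * a[i]
			nb += b[i] * b[i]
		}
		return dot / (math.Sqrt(na) * math.Sqrt(nb))
	}

	// topK returns the k texts whose vectors are most similar to query.
	func topK(k int, query []float64, vecs [][]float64, texts []string) []string {
		results := make([]scored, len(vecs))
		for i, v := range vecs {
			results[i] = scored{texts[i], cosine(query, v)}
		}
		// sort by descending similarity, then take the first k
		sort.Slice(results, func(i, j int) bool { return results[i].score > results[j].score })
		if k > len(results) {
			k = len(results)
		}
		out := make([]string, k)
		for i := 0; i < k; i++ {
			out[i] = results[i].text
		}
		return out
	}

	func main() {
		query := []float64{1, 0}
		vecs := [][]float64{{0.9, 0.1}, {0, 1}, {0.7, 0.7}}
		texts := []string{"close match", "orthogonal", "diagonal"}
		fmt.Println(topK(2, query, vecs, texts)) // [close match diagonal]
	}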