embed text document in modelfile

This commit is contained in:
Bruce MacDonald 2023-08-04 18:56:40 -04:00
parent 34a13a9d05
commit a6f6d18f83
8 changed files with 330 additions and 59 deletions

View file

@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop() spinner.Stop()
} }
currentDigest = resp.Digest currentDigest = resp.Digest
switch {
case strings.Contains(resp.Status, "embeddings"):
bar = progressbar.Default(int64(resp.Total), resp.Status)
bar.Set(resp.Completed)
default:
// pulling
bar = progressbar.DefaultBytes( bar = progressbar.DefaultBytes(
int64(resp.Total), int64(resp.Total),
fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Status,
) )
bar.Set(resp.Completed) bar.Set(resp.Completed)
}
} else if resp.Digest == currentDigest && resp.Digest != "" { } else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed) bar.Set(resp.Completed)
} else { } else {

1
go.mod
View file

@ -42,6 +42,7 @@ require (
golang.org/x/sys v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0 golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect
gonum.org/v1/gonum v0.13.0
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

2
go.sum
View file

@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View file

@ -85,6 +85,7 @@ llama_token llama_sample(
} }
*/ */
import "C" import "C"
import ( import (
"bytes" "bytes"
"embed" "embed"
@ -93,6 +94,7 @@ import (
"io" "io"
"log" "log"
"os" "os"
"reflect"
"strings" "strings"
"sync" "sync"
"unicode/utf8" "unicode/utf8"
@ -414,3 +416,38 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil return token, nil
} }
func (llm *LLM) Embedding(input string) ([]float64, error) {
if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled")
}
tokens := llm.tokenize(input)
if tokens == nil {
return nil, errors.New("llama: tokenize embedding")
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
if retval != 0 {
return nil, errors.New("llama: eval")
}
n := int(C.llama_n_embd(llm.ctx))
if n <= 0 {
return nil, errors.New("llama: no embeddings generated")
}
embedPtr := C.llama_get_embeddings(llm.ctx)
if embedPtr == nil {
return nil, errors.New("llama: embedding retrieval failed")
}
header := reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(embedPtr)),
Len: n,
Cap: n,
}
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
return embedSlice, nil
}

View file

@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(fields[1]) command.Args = string(fields[1])
// copy command for validation // copy command for validation
modelCommand = command modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
command.Name = string(bytes.ToLower(fields[0])) command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1]) command.Args = string(fields[1])
case "PARAMETER": case "PARAMETER":

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"bufio"
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
@ -9,6 +10,7 @@ import (
"html/template" "html/template"
"io" "io"
"log" "log"
"math"
"net/http" "net/http"
"os" "os"
"path" "path"
@ -18,7 +20,10 @@ import (
"strings" "strings"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
"gonum.org/v1/gonum/mat"
) )
type RegistryOptions struct { type RegistryOptions struct {
@ -34,6 +39,7 @@ type Model struct {
System string System string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Embeddings []vector.Embedding
} }
func (m *Model) Prompt(request api.GenerateRequest) (string, error) { func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
@ -51,6 +57,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
First bool First bool
System string System string
Prompt string Prompt string
Embed string
// deprecated: versions <= 0.0.7 used this to omit the system prompt // deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int Context []int
@ -65,6 +72,21 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
vars.System = request.System vars.System = request.System
} }
if len(m.Embeddings) > 0 {
promptEmbed, err := loaded.llm.Embedding(request.Prompt)
if err != nil {
return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
}
// TODO: set embed_top from specified parameters in modelfile
embed_top := 3
embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
toEmbed := ""
for _, e := range embed {
toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
}
vars.Embed = toEmbed
}
var sb strings.Builder var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil { if err := tmpl.Execute(&sb, vars); err != nil {
return "", err return "", err
@ -157,6 +179,16 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
case "application/vnd.ollama.image.embed":
file, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open file: %s", filename)
}
defer file.Close()
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
return nil, err
}
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -195,6 +227,26 @@ func GetModel(name string) (*Model, error) {
return model, nil return model, nil
} }
func filenameWithPath(path, f string) (string, error) {
// if filePath starts with ~/, replace it with the user's home directory.
if strings.HasPrefix(f, "~/") {
parts := strings.Split(f, "/")
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to open file: %v", err)
}
f = filepath.Join(home, filepath.Join(parts[1:]...))
}
// if filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(f) {
f = filepath.Join(filepath.Dir(path), f)
}
return f, nil
}
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error { func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
mf, err := os.Open(path) mf, err := os.Open(path)
if err != nil { if err != nil {
@ -211,52 +263,37 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args
mf, err := GetManifest(ParseModelPath(c.Args)) mf, err := GetManifest(ParseModelPath(c.Args))
if err != nil { if err != nil {
fp := c.Args modelFile, err := filenameWithPath(path, c.Args)
// If filePath starts with ~/, replace it with the user's home directory.
if strings.HasPrefix(fp, "~/") {
parts := strings.Split(fp, "/")
home, err := os.UserHomeDir()
if err != nil { if err != nil {
return fmt.Errorf("failed to open file: %v", err) return err
} }
if _, err := os.Stat(modelFile); err != nil {
fp = filepath.Join(home, filepath.Join(parts[1:]...))
}
// If filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(fp) {
fp = filepath.Join(filepath.Dir(path), fp)
}
if _, err := os.Stat(fp); err != nil {
// the model file does not exist, try pulling it // the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"}) fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
mf, err = GetManifest(ParseModelPath(c.Args)) mf, err = GetManifest(ParseModelPath(modelFile))
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err) return fmt.Errorf("failed to open file after pull: %v", err)
} }
} else { } else {
return err return err
} }
} else { } else {
// create a model from this specified file // create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"}) fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
file, err := os.Open(fp)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file: %v", err) return fmt.Errorf("failed to open file: %v", err)
} }
@ -280,19 +317,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
layers = append(layers, newLayer) layers = append(layers, newLayer)
} }
} }
case "license": case "embed":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) // TODO: support entire directories here
// remove the prompt layer if one exists embedFilePath, err := filenameWithPath(path, c.Args)
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
return err return err
} }
embed.files = append(embed.files, embedFilePath)
layer.MediaType = mediaType case "license", "template", "system", "prompt":
layers = append(layers, layer)
case "template", "system", "prompt":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists // remove the prompt layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@ -315,18 +347,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"}) fn(api.ProgressResponse{Status: "creating parameter layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
paramData, err := paramsToReader(params) formattedParams, err := formatParams(params)
if err != nil { if err != nil {
return fmt.Errorf("couldn't create params json: %v", err) return fmt.Errorf("couldn't create params json: %v", err)
} }
l, err := CreateLayer(paramData)
bts, err := json.Marshal(formattedParams)
if err != nil {
return err
}
l, err := CreateLayer(bytes.NewReader(bts))
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return fmt.Errorf("failed to create layer: %v", err)
} }
l.MediaType = "application/vnd.ollama.image.params" l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l) layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = api.DefaultOptions()
embed.opts.FromMap(formattedParams)
} }
// generate the embedding layers
embeddingLayers, err := embeddingLayers(embed)
if err != nil {
return err
}
layers = append(layers, embeddingLayers...)
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
if err != nil { if err != nil {
return err return err
@ -361,6 +410,112 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
return nil return nil
} }
type EmbeddingParams struct {
model string
opts api.Options
files []string // paths to files to embed
fn func(resp api.ProgressResponse)
}
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{}
if len(e.files) > 0 {
model, err := GetModel(e.model)
if err != nil {
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
}
e.opts.EmbeddingOnly = true
llm, err := llama.New(model.ModelPath, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
for _, filePath := range e.files {
// TODO: check if txt file type
f, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("could not open embed file: %w", err)
}
scanner := bufio.NewScanner(f)
scanner.Split(bufio.ScanLines)
data := []string{}
for scanner.Scan() {
data = append(data, scanner.Text())
}
f.Close()
// the digest of the file is set here so that the client knows a new operation is in progress
fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
embeddings := []vector.Embedding{}
for i, d := range data {
if strings.TrimSpace(d) == "" {
continue
}
e.fn(api.ProgressResponse{
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
Digest: fileDigest,
Total: len(data) - 1,
Completed: i,
})
retry := 0
generate:
if retry > 3 {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue
}
embed, err := llm.Embedding(d)
if err != nil {
log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
retry++
goto generate
}
// Check for NaN and Inf in the embedding, which can't be stored
for _, value := range embed {
if math.IsNaN(value) || math.IsInf(value, 0) {
log.Printf("reloading model, embedding contains NaN or Inf")
// reload the model to get a new embedding
llm, err = llama.New(model.ModelPath, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
retry++
goto generate
}
}
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
}
b, err := json.Marshal(embeddings)
if err != nil {
return nil, fmt.Errorf("failed to encode embeddings: %w", err)
}
r := bytes.NewReader(b)
digest, size := GetSHA256Digest(r)
// Reset the position of the reader after calculating the digest
if _, err := r.Seek(0, 0); err != nil {
return nil, fmt.Errorf("could not reset embed reader: %w", err)
}
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.ollama.image.embed",
Digest: digest,
Size: size,
},
Reader: r,
}
layers = append(layers, layer)
}
}
return layers, nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader { func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
j := 0 j := 0
for _, l := range layers { for _, l := range layers {
@ -449,8 +604,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil return newLayer, nil
} }
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json // formatParams converts specified parameter options to their correct types
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{} opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@ -504,12 +659,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
} }
} }
bts, err := json.Marshal(out) return out, nil
if err != nil {
return nil, err
}
return bytes.NewReader(bts), nil
} }
func getLayerDigests(layers []*LayerReader) ([]string, error) { func getLayerDigests(layers []*LayerReader) ([]string, error) {
@ -1042,7 +1192,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
for { for {
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", digest), Status: fmt.Sprintf("pulling %s...", digest[7:19]),
Digest: digest, Digest: digest,
Total: int(total), Total: int(total),
Completed: int(completed), Completed: int(completed),

View file

@ -20,12 +20,14 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/vector"
) )
var loaded struct { var loaded struct {
mu sync.Mutex mu sync.Mutex
llm *llama.LLM llm *llama.LLM
Embeddings []vector.Embedding
expireAt time.Time expireAt time.Time
expireTimer *time.Timer expireTimer *time.Timer
@ -72,6 +74,11 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = "" loaded.digest = ""
} }
if model.Embeddings != nil && len(model.Embeddings) > 0 {
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
loaded.Embeddings = model.Embeddings
}
llm, err := llama.New(model.ModelPath, opts) llm, err := llama.New(model.ModelPath, opts)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -82,7 +89,6 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = model.Digest loaded.digest = model.Digest
loaded.options = opts loaded.options = opts
} }
sessionDuration := 5 * time.Minute sessionDuration := 5 * time.Minute
loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireAt = time.Now().Add(sessionDuration)

69
vector/store.go Normal file
View file

@ -0,0 +1,69 @@
package vector
import (
"container/heap"
"sort"
"gonum.org/v1/gonum/mat"
)
type Embedding struct {
Vector []float64 // the embedding vector
Data string // the data represted by the embedding
}
type EmbeddingSimilarity struct {
Embedding Embedding // the embedding that was used to calculate the similarity
Similarity float64 // the similarity between the embedding and the query
}
type Heap []EmbeddingSimilarity
func (h Heap) Len() int { return len(h) }
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *Heap) Push(e any) {
*h = append(*h, e.(EmbeddingSimilarity))
}
func (h *Heap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
// This value will range from -1 to 1, where 1 means the vectors are identical.
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
dotProduct := mat.Dot(vec1, vec2)
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
if norms == 0 {
return 0
}
return dotProduct / norms
}
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
h := &Heap{}
heap.Init(h)
for _, emb := range embeddings {
similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
if h.Len() > k {
heap.Pop(h)
}
}
topK := make([]EmbeddingSimilarity, 0, h.Len())
for h.Len() > 0 {
topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
}
sort.Slice(topK, func(i, j int) bool {
return topK[i].Similarity > topK[j].Similarity
})
return topK
}