diff --git a/cmd/cmd.go b/cmd/cmd.go
index 9526c864..1e0d1f59 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			spinner.Stop()
 		}
 		currentDigest = resp.Digest
-		bar = progressbar.DefaultBytes(
-			int64(resp.Total),
-			fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
-		)
-
-		bar.Set(resp.Completed)
+		switch {
+		case strings.Contains(resp.Status, "embeddings"):
+			bar = progressbar.Default(int64(resp.Total), resp.Status)
+			bar.Set(resp.Completed)
+		default:
+			// pulling
+			bar = progressbar.DefaultBytes(
+				int64(resp.Total),
+				resp.Status,
+			)
+			bar.Set(resp.Completed)
+		}
 	} else if resp.Digest == currentDigest && resp.Digest != "" {
 		bar.Set(resp.Completed)
 	} else {
diff --git a/go.mod b/go.mod
index 554473cb..a0583e65 100644
--- a/go.mod
+++ b/go.mod
@@ -42,6 +42,7 @@ require (
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
+	gonum.org/v1/gonum v0.13.0
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index c4097bdb..7ec060d3 100644
--- a/go.sum
+++ b/go.sum
@@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
 golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
+gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
diff --git a/llama/llama.go b/llama/llama.go
index 0a523321..2c11cbc3 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -85,6 +85,7 @@ llama_token llama_sample(
 }
 */
 import "C"
+
 import (
 	"bytes"
 	"embed"
@@ -414,3 +415,38 @@ func (llm *LLM) next() (C.llama_token, error) {
 
 	return token, nil
 }
+
+func (llm *LLM) Embedding(input string) ([]float64, error) {
+	if !llm.EmbeddingOnly {
+		return nil, errors.New("llama: embedding not enabled")
+	}
+
+	tokens := llm.tokenize(input)
+	if tokens == nil {
+		return nil, errors.New("llama: tokenize embedding")
+	}
+
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
+	if retval != 0 {
+		return nil, errors.New("llama: eval")
+	}
+
+	n := int(C.llama_n_embd(llm.ctx))
+	if n <= 0 {
+		return nil, errors.New("llama: no embeddings generated")
+	}
+
+	embedPtr := C.llama_get_embeddings(llm.ctx)
+	if embedPtr == nil {
+		return nil, errors.New("llama: embedding retrieval failed")
+	}
+
+	// llama_get_embeddings returns a C float (32-bit) buffer owned by the
+	// context; copy it into a Go-owned slice, widening each value to float64
+	embedding := make([]float64, n)
+	for i, v := range unsafe.Slice(embedPtr, n) {
+		embedding[i] = float64(v)
+	}
+
+	return embedding, nil
+}
diff --git a/parser/parser.go b/parser/parser.go
index c89b13e6..06ccf786 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(fields[1])
 			// copy command for validation
 			modelCommand = command
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
+		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
 			command.Name = string(bytes.ToLower(fields[0]))
 			command.Args = string(fields[1])
 		case "PARAMETER":
diff --git a/server/images.go b/server/images.go
index e06a40a1..4dcaf37d 100644
--- a/server/images.go
+++ b/server/images.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/sha256"
 	"encoding/json"
@@ -9,6 +10,7 @@ import (
 	"html/template"
 	"io"
 	"log"
+	"math"
 	"net/http"
 	"os"
 	"path"
@@ -18,7 +20,10 @@ import (
 	"strings"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llama"
 	"github.com/jmorganca/ollama/parser"
+	"github.com/jmorganca/ollama/vector"
+	"gonum.org/v1/gonum/mat"
 )
 
 type RegistryOptions struct {
@@ -28,12 +33,13 @@ type RegistryOptions struct {
 }
 
 type Model struct {
-	Name      string `json:"name"`
-	ModelPath string
-	Template  string
-	System    string
-	Digest    string
-	Options   map[string]interface{}
+	Name       string `json:"name"`
+	ModelPath  string
+	Template   string
+	System     string
+	Digest     string
+	Options    map[string]interface{}
+	Embeddings []vector.Embedding
 }
 
 func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
@@ -51,6 +57,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		First  bool
 		System string
 		Prompt string
+		Embed  string
 
 		// deprecated: versions <= 0.0.7 used this to omit the system prompt
 		Context []int
@@ -65,6 +72,21 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 		vars.System = request.System
 	}
 
+	if len(m.Embeddings) > 0 {
+		promptEmbed, err := loaded.llm.Embedding(request.Prompt)
+		if err != nil {
+			return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
+		}
+		// TODO: set embedTop from specified parameters in modelfile
+		embedTop := 3
+		topK := vector.TopK(embedTop, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
+		toEmbed := ""
+		for _, e := range topK {
+			toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
+		}
+		vars.Embed = toEmbed
+	}
+
 	var sb strings.Builder
 	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
@@ -157,6 +179,16 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
+		case "application/vnd.ollama.image.embed":
+			file, err := os.Open(filename)
+			if err != nil {
+				return nil, fmt.Errorf("failed to open file %s: %w", filename, err)
+			}
+			defer file.Close()
+
+			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
@@ -195,6 +227,26 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
+func filenameWithPath(path, f string) (string, error) {
+	// if f starts with ~/, replace it with the user's home directory
+	if strings.HasPrefix(f, "~/") {
+		parts := strings.Split(f, "/")
+		home, err := os.UserHomeDir()
+		if err != nil {
+			return "", fmt.Errorf("failed to get user home directory: %v", err)
+		}
+
+		f = filepath.Join(home, filepath.Join(parts[1:]...))
+	}
+
+	// if f is not an absolute path, make it relative to the modelfile path
+	if !filepath.IsAbs(f) {
+		f = filepath.Join(filepath.Dir(path), f)
+	}
+
+	return f, nil
+}
+
 func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	mf, err := os.Open(path)
 	if err != nil {
@@ -211,52 +263,37 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 
 	var layers []*LayerReader
 	params := make(map[string][]string)
-
+	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
+			embed.model = c.Args
 			mf, err := GetManifest(ParseModelPath(c.Args))
 			if err != nil {
-				fp := c.Args
-
-				// If filePath starts with ~/, replace it with the user's home directory.
-				if strings.HasPrefix(fp, "~/") {
-					parts := strings.Split(fp, "/")
-					home, err := os.UserHomeDir()
-					if err != nil {
-						return fmt.Errorf("failed to open file: %v", err)
-					}
-
-					fp = filepath.Join(home, filepath.Join(parts[1:]...))
+				modelFile, err := filenameWithPath(path, c.Args)
+				if err != nil {
+					return err
 				}
-
-				// If filePath is not an absolute path, make it relative to the modelfile path
-				if !filepath.IsAbs(fp) {
-					fp = filepath.Join(filepath.Dir(path), fp)
-				}
-
-				if _, err := os.Stat(fp); err != nil {
+				if _, err := os.Stat(modelFile); err != nil {
 					// the model file does not exist, try pulling it
 					if errors.Is(err, os.ErrNotExist) {
 						fn(api.ProgressResponse{Status: "pulling model file"})
 						if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
 							return err
 						}
 						mf, err = GetManifest(ParseModelPath(c.Args))
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
-
 					} else {
 						return err
 					}
 				} else {
 					// create a model from this specified file
 					fn(api.ProgressResponse{Status: "creating model layer"})
-
-					file, err := os.Open(fp)
+					file, err := os.Open(modelFile)
 					if err != nil {
 						return fmt.Errorf("failed to open file: %v", err)
 					}
@@ -280,19 +317,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 					layers = append(layers, newLayer)
 				}
 			}
-		case "license":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			// remove the prompt layer if one exists
-			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
-
-			layer, err := CreateLayer(strings.NewReader(c.Args))
+		case "embed":
+			// TODO: support entire directories here
+			embedFilePath, err := filenameWithPath(path, c.Args)
 			if err != nil {
 				return err
 			}
-
-			layer.MediaType = mediaType
-			layers = append(layers, layer)
-		case "template", "system", "prompt":
+			embed.files = append(embed.files, embedFilePath)
+		case "license", "template", "system", "prompt":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
 			// remove the prompt layer if one exists
 			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@@ -315,18 +347,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameter layer"})
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
-		paramData, err := paramsToReader(params)
+		formattedParams, err := formatParams(params)
 		if err != nil {
 			return fmt.Errorf("couldn't create params json: %v", err)
 		}
-		l, err := CreateLayer(paramData)
+
+		bts, err := json.Marshal(formattedParams)
+		if err != nil {
+			return err
+		}
+
+		l, err := CreateLayer(bytes.NewReader(bts))
 		if err != nil {
 			return fmt.Errorf("failed to create layer: %v", err)
 		}
 		l.MediaType = "application/vnd.ollama.image.params"
 		layers = append(layers, l)
+
+		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
+		embed.opts = api.DefaultOptions()
+		embed.opts.FromMap(formattedParams)
 	}
 
+	// generate the embedding layers
+	embeddingLayers, err := embeddingLayers(embed)
+	if err != nil {
+		return err
+	}
+	layers = append(layers, embeddingLayers...)
+
 	digests, err := getLayerDigests(layers)
 	if err != nil {
 		return err
@@ -361,6 +410,112 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
 	return nil
 }
 
+type EmbeddingParams struct {
+	model string
+	opts  api.Options
+	files []string // paths to files to embed
+	fn    func(resp api.ProgressResponse)
+}
+
+// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
+func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
+	layers := []*LayerReader{}
+	if len(e.files) > 0 {
+		model, err := GetModel(e.model)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
+		}
+
+		e.opts.EmbeddingOnly = true
+		llm, err := llama.New(model.ModelPath, e.opts)
+		if err != nil {
+			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+		}
+
+		for _, filePath := range e.files {
+			// TODO: check if txt file type
+			f, err := os.Open(filePath)
+			if err != nil {
+				return nil, fmt.Errorf("could not open embed file: %w", err)
+			}
+			scanner := bufio.NewScanner(f)
+			scanner.Split(bufio.ScanLines)
+
+			data := []string{}
+			for scanner.Scan() {
+				data = append(data, scanner.Text())
+			}
+			f.Close()
+
+			// the digest of the file is set here so that the client knows a new operation is in progress
+			fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
+
+			embeddings := []vector.Embedding{}
+			for i, d := range data {
+				if strings.TrimSpace(d) == "" {
+					continue
+				}
+				e.fn(api.ProgressResponse{
+					Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
+					Digest:    fileDigest,
+					Total:     len(data) - 1,
+					Completed: i,
+				})
+				retry := 0
+			generate:
+				if retry > 3 {
+					log.Printf("failed to generate embedding for '%s' line %d", filePath, i+1)
+					continue
+				}
+				embed, err := llm.Embedding(d)
+				if err != nil {
+					log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
+					retry++
+					goto generate
+				}
+				// check for NaN and Inf in the embedding, which can't be stored
+				for _, value := range embed {
+					if math.IsNaN(value) || math.IsInf(value, 0) {
+						log.Printf("reloading model, embedding contains NaN or Inf")
+						// reload the model to get a new embedding
+						llm, err = llama.New(model.ModelPath, e.opts)
+						if err != nil {
+							return nil, fmt.Errorf("load model to generate embeddings: %v", err)
+						}
+						retry++
+						goto generate
+					}
+				}
+				embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
+			}
+
+			b, err := json.Marshal(embeddings)
+			if err != nil {
+				return nil, fmt.Errorf("failed to encode embeddings: %w", err)
+			}
+			r := bytes.NewReader(b)
+
+			digest, size := GetSHA256Digest(r)
+			// reset the position of the reader after calculating the digest
+			if _, err := r.Seek(0, 0); err != nil {
+				return nil, fmt.Errorf("could not reset embed reader: %w", err)
+			}
+
+			layer := &LayerReader{
+				Layer: Layer{
+					MediaType: "application/vnd.ollama.image.embed",
+					Digest:    digest,
+					Size:      size,
+				},
+				Reader: r,
+			}
+
+			layers = append(layers, layer)
+		}
+	}
+	return layers, nil
+}
+
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
 	j := 0
 	for _, l := range layers {
@@ -449,8 +604,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
 	return newLayer, nil
 }
 
-// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
-func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
+// formatParams converts specified parameter options to their correct types
+func formatParams(params map[string][]string) (map[string]interface{}, error) {
 	opts := api.Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
@@ -504,12 +659,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 		}
 	}
 
-	bts, err := json.Marshal(out)
-	if err != nil {
-		return nil, err
-	}
-
-	return bytes.NewReader(bts), nil
+	return out, nil
 }
 
 func getLayerDigests(layers []*LayerReader) ([]string, error) {
@@ -1042,7 +1192,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 
 	for {
 		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("downloading %s", digest),
+			Status:    fmt.Sprintf("pulling %s...", digest[7:19]),
 			Digest:    digest,
 			Total:     int(total),
 			Completed: int(completed),
diff --git a/server/routes.go b/server/routes.go
index 83afef1a..2a880aaa 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -20,12 +20,14 @@ import (
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
+	"github.com/jmorganca/ollama/vector"
 )
 
 var loaded struct {
 	mu sync.Mutex
 
-	llm *llama.LLM
+	llm        *llama.LLM
+	Embeddings []vector.Embedding
 
 	expireAt    time.Time
 	expireTimer *time.Timer
@@ -72,6 +74,11 @@ func GenerateHandler(c *gin.Context) {
 			loaded.digest = ""
 		}
 
+		if len(model.Embeddings) > 0 {
+			opts.EmbeddingOnly = true // this is required to generate embeddings; completions will still work
+			loaded.Embeddings = model.Embeddings
+		}
+
 		llm, err := llama.New(model.ModelPath, opts)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -82,7 +89,6 @@ func GenerateHandler(c *gin.Context) {
 		loaded.digest = model.Digest
 		loaded.options = opts
 	}
-
 	sessionDuration := 5 * time.Minute
 	loaded.expireAt = time.Now().Add(sessionDuration)
diff --git a/vector/store.go b/vector/store.go
new file mode 100644
index 00000000..510470d8
--- /dev/null
+++ b/vector/store.go
@@ -0,0 +1,72 @@
+package vector
+
+import (
+	"container/heap"
+	"sort"
+
+	"gonum.org/v1/gonum/mat"
+)
+
+type Embedding struct {
+	Vector []float64 // the embedding vector
+	Data   string    // the data represented by the embedding
+}
+
+type EmbeddingSimilarity struct {
+	Embedding  Embedding // the embedding that was used to calculate the similarity
+	Similarity float64   // the similarity between the embedding and the query
+}
+
+type Heap []EmbeddingSimilarity
+
+func (h Heap) Len() int           { return len(h) }
+func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
+func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+func (h *Heap) Push(e any) {
+	*h = append(*h, e.(EmbeddingSimilarity))
+}
+
+func (h *Heap) Pop() any {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+// cosineSimilarity calculates the cosine of the angle between two vectors.
+// The value ranges from -1 to 1, where 1 means the vectors point in the same direction.
+func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
+	dotProduct := mat.Dot(vec1, vec2)
+	norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
+
+	if norms == 0 {
+		return 0
+	}
+	return dotProduct / norms
+}
+
+// TopK returns the k embeddings most similar to the query, ordered from most
+// to least similar. A min-heap of size k keeps the best matches seen so far,
+// so the scan costs O(n log k) rather than a full sort of all n similarities.
+func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
+	h := &Heap{}
+	heap.Init(h)
+	for _, emb := range embeddings {
+		similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
+		heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
+		if h.Len() > k {
+			heap.Pop(h)
+		}
+	}
+
+	topK := make([]EmbeddingSimilarity, 0, h.Len())
+	for h.Len() > 0 {
+		topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
+	}
+	sort.Slice(topK, func(i, j int) bool {
+		return topK[i].Similarity > topK[j].Similarity
+	})
+
+	return topK
+}
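
Taken together, the parser and CreateModel changes let a Modelfile declare a text file whose lines are embedded once at create time and retrieved at generate time. A rough sketch of the intended usage, where the base model name and facts.txt are placeholders rather than anything this patch ships, and .Embed is the template variable added to the vars struct in Model.Prompt:

FROM llama2
EMBED ./facts.txt
TEMPLATE """
Context: {{ .Embed }}

{{ .Prompt }}
"""

At generate time the incoming prompt is itself embedded, TopK selects the most similar stored lines (embedTop is currently hardcoded to 3), and their text is concatenated into .Embed before the template executes.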
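The retrieval path in vector/store.go can also be exercised on its own, which is useful for verifying the heap behavior without loading a model. A minimal sketch, assuming the jmorganca/ollama module path; the three-dimensional vectors are toy values, whereas real vectors are model-width and come from llm.Embedding:

package main

import (
	"fmt"

	"github.com/jmorganca/ollama/vector"
	"gonum.org/v1/gonum/mat"
)

func main() {
	// toy entries standing in for embedded Modelfile lines
	store := []vector.Embedding{
		{Data: "the sky is blue", Vector: []float64{0.9, 0.1, 0.0}},
		{Data: "grass is green", Vector: []float64{0.1, 0.9, 0.0}},
		{Data: "oceans are blue", Vector: []float64{0.8, 0.2, 0.1}},
	}

	// a query vector close to the two "blue" entries
	query := mat.NewVecDense(3, []float64{0.85, 0.15, 0.05})

	// keep the 2 nearest entries; results come back most similar first
	for _, match := range vector.TopK(2, query, store) {
		fmt.Printf("%.3f  %s\n", match.Similarity, match.Embedding.Data)
	}
}

This should print the two "blue" entries and skip "grass is green": each push past size k evicts the current minimum similarity, so only the strongest k candidates survive the scan.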