embed text document in modelfile

This commit is contained in:
Bruce MacDonald 2023-08-09 10:26:19 -04:00 committed by GitHub
commit 7a5f3616fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 371 additions and 52 deletions

View file

@ -276,6 +276,7 @@ func DefaultOptions() Options {
UseMLock: false, UseMLock: false,
RopeFrequencyBase: 10000.0, RopeFrequencyBase: 10000.0,
RopeFrequencyScale: 1.0, RopeFrequencyScale: 1.0,
EmbeddingOnly: true,
RepeatLastN: 64, RepeatLastN: 64,
RepeatPenalty: 1.1, RepeatPenalty: 1.1,

View file

@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop() spinner.Stop()
} }
currentDigest = resp.Digest currentDigest = resp.Digest
bar = progressbar.DefaultBytes( switch {
int64(resp.Total), case strings.Contains(resp.Status, "embeddings"):
fmt.Sprintf("pulling %s...", resp.Digest[7:19]), bar = progressbar.Default(int64(resp.Total), resp.Status)
) bar.Set(resp.Completed)
default:
bar.Set(resp.Completed) // pulling
bar = progressbar.DefaultBytes(
int64(resp.Total),
resp.Status,
)
bar.Set(resp.Completed)
}
} else if resp.Digest == currentDigest && resp.Digest != "" { } else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed) bar.Set(resp.Completed)
} else { } else {

View file

@ -12,6 +12,7 @@ A model file is the blueprint to create and share models with Ollama.
- [FROM (Required)](#from-required) - [FROM (Required)](#from-required)
- [Build from llama2](#build-from-llama2) - [Build from llama2](#build-from-llama2)
- [Build from a bin file](#build-from-a-bin-file) - [Build from a bin file](#build-from-a-bin-file)
- [EMBED](#embed)
- [PARAMETER](#parameter) - [PARAMETER](#parameter)
- [Valid Parameters and Values](#valid-parameters-and-values) - [Valid Parameters and Values](#valid-parameters-and-values)
- [TEMPLATE](#template) - [TEMPLATE](#template)
@ -88,6 +89,15 @@ FROM ./ollama-model.bin
This bin file location should be specified as an absolute path or relative to the Modelfile location. This bin file location should be specified as an absolute path or relative to the Modelfile location.
### EMBED
The EMBED instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer.
```
FROM <model name>:<tag>
EMBED <file path>
```
### PARAMETER ### PARAMETER
The `PARAMETER` instruction defines a parameter that can be set when the model is run. The `PARAMETER` instruction defines a parameter that can be set when the model is run.
@ -163,4 +173,4 @@ LICENSE """
## Notes ## Notes
- the **modelfile is not case sensitive**. In the examples, we use uppercase for instructions to make it easier to distinguish it from arguments. - the **modelfile is not case sensitive**. In the examples, we use uppercase for instructions to make it easier to distinguish it from arguments.
- Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable. - Instructions can be in any order. In the examples, we start with FROM instruction to keep it easily readable.

1
go.mod
View file

@ -42,6 +42,7 @@ require (
golang.org/x/sys v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0 golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect
gonum.org/v1/gonum v0.13.0
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

2
go.sum
View file

@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View file

@ -85,6 +85,7 @@ llama_token llama_sample(
} }
*/ */
import "C" import "C"
import ( import (
"bytes" "bytes"
"embed" "embed"
@ -93,6 +94,7 @@ import (
"io" "io"
"log" "log"
"os" "os"
"reflect"
"strings" "strings"
"sync" "sync"
"unicode/utf8" "unicode/utf8"
@ -408,3 +410,38 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil return token, nil
} }
// Embedding tokenizes input, evaluates it through the model, and returns
// the resulting embedding vector. The model must have been loaded with
// EmbeddingOnly enabled, otherwise an error is returned.
func (llm *LLM) Embedding(input string) ([]float64, error) {
	if !llm.EmbeddingOnly {
		return nil, errors.New("llama: embedding not enabled")
	}

	tokens := llm.tokenize(input)
	if tokens == nil {
		return nil, errors.New("llama: tokenize embedding")
	}

	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
	if retval != 0 {
		return nil, errors.New("llama: eval")
	}

	n := int(C.llama_n_embd(llm.ctx))
	if n <= 0 {
		return nil, errors.New("llama: no embeddings generated")
	}

	embedPtr := C.llama_get_embeddings(llm.ctx)
	if embedPtr == nil {
		return nil, errors.New("llama: embedding retrieval failed")
	}

	// llama_get_embeddings returns a C array of 32-bit floats owned by the
	// llama context. Copy the values into a Go-owned []float64 instead of
	// reinterpreting the memory: the previous reflect.SliceHeader cast read
	// the 32-bit floats as 64-bit values (producing garbage) and returned a
	// slice aliasing C memory that is overwritten on the next eval.
	cEmbeddings := unsafe.Slice(embedPtr, n)
	embeddings := make([]float64, n)
	for i, v := range cEmbeddings {
		embeddings[i] = float64(v)
	}
	return embeddings, nil
}

View file

@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(fields[1]) command.Args = string(fields[1])
// copy command for validation // copy command for validation
modelCommand = command modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
command.Name = string(bytes.ToLower(fields[0])) command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1]) command.Args = string(fields[1])
case "PARAMETER": case "PARAMETER":

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"bufio"
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
@ -9,6 +10,7 @@ import (
"html/template" "html/template"
"io" "io"
"log" "log"
"math"
"net/http" "net/http"
"os" "os"
"path" "path"
@ -18,7 +20,9 @@ import (
"strings" "strings"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
) )
type RegistryOptions struct { type RegistryOptions struct {
@ -28,15 +32,16 @@ type RegistryOptions struct {
} }
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
ModelPath string ModelPath string
Template string Template string
System string System string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Embeddings []vector.Embedding
} }
func (m *Model) Prompt(request api.GenerateRequest) (string, error) { func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
t := m.Template t := m.Template
if request.Template != "" { if request.Template != "" {
t = request.Template t = request.Template
@ -51,6 +56,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
First bool First bool
System string System string
Prompt string Prompt string
Embed string
// deprecated: versions <= 0.0.7 used this to omit the system prompt // deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int Context []int
@ -60,6 +66,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
vars.System = m.System vars.System = m.System
vars.Prompt = request.Prompt vars.Prompt = request.Prompt
vars.Context = request.Context vars.Context = request.Context
vars.Embed = embedding
if request.System != "" { if request.System != "" {
vars.System = request.System vars.System = request.System
@ -157,6 +164,16 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
case "application/vnd.ollama.image.embed":
file, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open file: %s", filename)
}
defer file.Close()
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
return nil, err
}
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -195,6 +212,26 @@ func GetModel(name string) (*Model, error) {
return model, nil return model, nil
} }
func filenameWithPath(path, f string) (string, error) {
// if filePath starts with ~/, replace it with the user's home directory.
if strings.HasPrefix(f, "~/") {
parts := strings.Split(f, "/")
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to open file: %v", err)
}
f = filepath.Join(home, filepath.Join(parts[1:]...))
}
// if filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(f) {
f = filepath.Join(filepath.Dir(path), f)
}
return f, nil
}
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error { func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
mf, err := os.Open(path) mf, err := os.Open(path)
if err != nil { if err != nil {
@ -211,33 +248,20 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args
mf, err := GetManifest(ParseModelPath(c.Args)) mf, err := GetManifest(ParseModelPath(c.Args))
if err != nil { if err != nil {
fp := c.Args modelFile, err := filenameWithPath(path, c.Args)
if err != nil {
// If filePath starts with ~/, replace it with the user's home directory. return err
if strings.HasPrefix(fp, "~/") {
parts := strings.Split(fp, "/")
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
fp = filepath.Join(home, filepath.Join(parts[1:]...))
} }
if _, err := os.Stat(modelFile); err != nil {
// If filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(fp) {
fp = filepath.Join(filepath.Dir(path), fp)
}
if _, err := os.Stat(fp); err != nil {
// the model file does not exist, try pulling it // the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"}) fn(api.ProgressResponse{Status: "pulling model file"})
@ -248,15 +272,13 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err) return fmt.Errorf("failed to open file after pull: %v", err)
} }
} else { } else {
return err return err
} }
} else { } else {
// create a model from this specified file // create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"}) fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
file, err := os.Open(fp)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file: %v", err) return fmt.Errorf("failed to open file: %v", err)
} }
@ -280,9 +302,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
layers = append(layers, newLayer) layers = append(layers, newLayer)
} }
} }
case "embed":
embedFilePath, err := filenameWithPath(path, c.Args)
if err != nil {
return err
}
embed.files = append(embed.files, embedFilePath)
case "license": case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(strings.NewReader(c.Args))
@ -315,18 +342,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"}) fn(api.ProgressResponse{Status: "creating parameter layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
paramData, err := paramsToReader(params) formattedParams, err := formatParams(params)
if err != nil { if err != nil {
return fmt.Errorf("couldn't create params json: %v", err) return fmt.Errorf("couldn't create params json: %v", err)
} }
l, err := CreateLayer(paramData)
bts, err := json.Marshal(formattedParams)
if err != nil {
return err
}
l, err := CreateLayer(bytes.NewReader(bts))
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return fmt.Errorf("failed to create layer: %v", err)
} }
l.MediaType = "application/vnd.ollama.image.params" l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l) layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = api.DefaultOptions()
embed.opts.FromMap(formattedParams)
} }
// generate the embedding layers
embeddingLayers, err := embeddingLayers(embed)
if err != nil {
return err
}
layers = append(layers, embeddingLayers...)
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
if err != nil { if err != nil {
return err return err
@ -361,6 +405,138 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
return nil return nil
} }
// EmbeddingParams collects everything needed to generate embedding layers
// while creating a model: the model used to produce the embeddings, the
// options to load it with, the files to embed, and a progress callback.
type EmbeddingParams struct {
	model string      // model name or model-file path used to generate embeddings
	opts  api.Options // load options; EmbeddingOnly is forced on before use
	files []string    // paths to files to embed
	fn    func(resp api.ProgressResponse) // reports embedding progress back to the client
}
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
//
// For every pattern in e.files, each non-empty line of every matching file
// is embedded with the model referenced by e.model; the resulting
// (text, vector) pairs are JSON-encoded into one
// "application/vnd.ollama.image.embed" layer per file. Progress is reported
// through e.fn. Returns the created layers (empty when e.files is empty).
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
	layers := []*LayerReader{}
	if len(e.files) > 0 {
		// e.model may be either a path to a model file or the name of a
		// previously created model; resolve a name to its stored model path
		if _, err := os.Stat(e.model); err != nil {
			if os.IsNotExist(err) {
				// this is a model name rather than the file
				model, err := GetModel(e.model)
				if err != nil {
					return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
				}
				e.model = model.ModelPath
			} else {
				return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
			}
		}

		e.opts.EmbeddingOnly = true
		llm, err := llama.New(e.model, e.opts)
		if err != nil {
			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
		}
		// llm may be reassigned below when the model is reloaded, so close
		// whichever instance is current at return time
		defer func() {
			if llm != nil {
				llm.Close()
			}
		}()

		addedFiles := make(map[string]bool) // keep track of files that have already been added
		for _, filePattern := range e.files {
			matchingFiles, err := filepath.Glob(filePattern)
			if err != nil {
				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
			}
			for _, filePath := range matchingFiles {
				if addedFiles[filePath] {
					continue
				}
				addedFiles[filePath] = true
				// TODO: check file type
				f, err := os.Open(filePath)
				if err != nil {
					return nil, fmt.Errorf("could not open embed file: %w", err)
				}
				// read the file line by line; each line is embedded separately
				scanner := bufio.NewScanner(f)
				scanner.Split(bufio.ScanLines)
				data := []string{}
				for scanner.Scan() {
					data = append(data, scanner.Text())
				}
				f.Close()
				// the digest of the file is set here so that the client knows a new operation is in progress
				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
				embeddings := []vector.Embedding{}
				for i, d := range data {
					// skip blank lines; they carry no content to embed
					if strings.TrimSpace(d) == "" {
						continue
					}
					e.fn(api.ProgressResponse{
						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
						Digest:    fileDigest,
						Total:     len(data) - 1,
						Completed: i,
					})
					// retry embedding generation up to 3 times per line; the
					// model is reloaded when the vector contains NaN/Inf (below)
					retry := 0
				generate:
					if retry > 3 {
						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
						continue
					}
					embed, err := llm.Embedding(d)
					if err != nil {
						log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
						retry++
						goto generate
					}
					// Check for NaN and Inf in the embedding, which can't be stored
					for _, value := range embed {
						if math.IsNaN(value) || math.IsInf(value, 0) {
							log.Printf("reloading model, embedding contains NaN or Inf")
							// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
							llm.Close()
							llm, err = llama.New(e.model, e.opts)
							if err != nil {
								return nil, fmt.Errorf("load model to generate embeddings: %v", err)
							}
							retry++
							goto generate
						}
					}
					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
				}
				// serialize this file's embeddings into a content-addressed layer
				b, err := json.Marshal(embeddings)
				if err != nil {
					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
				}
				r := bytes.NewReader(b)
				digest, size := GetSHA256Digest(r)
				// Reset the position of the reader after calculating the digest
				if _, err := r.Seek(0, io.SeekStart); err != nil {
					return nil, fmt.Errorf("could not reset embed reader: %w", err)
				}
				layer := &LayerReader{
					Layer: Layer{
						MediaType: "application/vnd.ollama.image.embed",
						Digest:    digest,
						Size:      size,
					},
					Reader: r,
				}
				layers = append(layers, layer)
			}
		}
	}
	return layers, nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader { func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
j := 0 j := 0
for _, l := range layers { for _, l := range layers {
@ -449,8 +625,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil return newLayer, nil
} }
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json // formatParams converts specified parameter options to their correct types
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{} opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@ -504,12 +680,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
} }
} }
bts, err := json.Marshal(out) return out, nil
if err != nil {
return nil, err
}
return bytes.NewReader(bts), nil
} }
func getLayerDigests(layers []*LayerReader) ([]string, error) { func getLayerDigests(layers []*LayerReader) ([]string, error) {
@ -1042,7 +1213,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
for { for {
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", digest), Status: fmt.Sprintf("pulling %s...", digest[7:19]),
Digest: digest, Digest: digest,
Total: int(total), Total: int(total),
Completed: int(completed), Completed: int(completed),

View file

@ -17,15 +17,18 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/vector"
) )
var loaded struct { var loaded struct {
mu sync.Mutex mu sync.Mutex
llm *llama.LLM llm *llama.LLM
Embeddings []vector.Embedding
expireAt time.Time expireAt time.Time
expireTimer *time.Timer expireTimer *time.Timer
@ -72,6 +75,11 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = "" loaded.digest = ""
} }
if model.Embeddings != nil && len(model.Embeddings) > 0 {
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
loaded.Embeddings = model.Embeddings
}
llm, err := llama.New(model.ModelPath, opts) llm, err := llama.New(model.ModelPath, opts)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -101,7 +109,6 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = model.Digest loaded.digest = model.Digest
loaded.options = opts loaded.options = opts
} }
sessionDuration := 5 * time.Minute sessionDuration := 5 * time.Minute
loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireAt = time.Now().Add(sessionDuration)
@ -127,7 +134,22 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
prompt, err := model.Prompt(req) embedding := ""
if model.Embeddings != nil && len(model.Embeddings) > 0 {
promptEmbed, err := loaded.llm.Embedding(req.Prompt)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO: set embed_top from specified parameters in modelfile
embed_top := 3
topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
for _, e := range topK {
embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
}
}
prompt, err := model.Prompt(req, embedding)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return

69
vector/store.go Normal file
View file

@ -0,0 +1,69 @@
package vector
import (
"container/heap"
"sort"
"gonum.org/v1/gonum/mat"
)
// Embedding is a stored embedding: a vector together with the source text
// it was generated from.
type Embedding struct {
	Vector []float64 // the embedding vector
	Data   string    // the data represented by the embedding
}
// EmbeddingSimilarity pairs an embedding with its similarity score
// relative to a query vector.
type EmbeddingSimilarity struct {
	Embedding  Embedding // the embedding that was used to calculate the similarity
	Similarity float64   // the similarity between the embedding and the query
}
// Heap is a min-heap of EmbeddingSimilarity ordered by Similarity: the
// least similar entry is always at the top, which makes it cheap to evict
// while collecting the top-k most similar embeddings. It implements
// container/heap's heap.Interface.
type Heap []EmbeddingSimilarity

func (h Heap) Len() int           { return len(h) }
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

// Push appends e to the heap's backing slice. It is called by
// container/heap; callers should use heap.Push instead.
func (h *Heap) Push(e any) {
	*h = append(*h, e.(EmbeddingSimilarity))
}

// Pop removes and returns the last element of the backing slice. It is
// called by container/heap; callers should use heap.Pop, which returns the
// entry with the smallest similarity.
func (h *Heap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[0 : n-1]
	return x
}
// cosineSimilarity returns the cosine of the angle between vec1 and vec2,
// a value in [-1, 1] where 1 means the vectors point the same way.
// When either vector has zero magnitude the similarity is defined as 0.
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
	magnitude := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
	if magnitude == 0 {
		// at least one zero vector: similarity is undefined, report 0
		return 0
	}
	return mat.Dot(vec1, vec2) / magnitude
}
// TopK returns the k embeddings most similar to query, ordered from most
// to least similar. When fewer than k embeddings are supplied, all of them
// are returned.
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
	// keep at most k candidates in a min-heap; the worst candidate sits on
	// top and is evicted whenever a better one arrives
	candidates := &Heap{}
	heap.Init(candidates)
	for _, emb := range embeddings {
		sim := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
		heap.Push(candidates, EmbeddingSimilarity{Embedding: emb, Similarity: sim})
		if candidates.Len() > k {
			heap.Pop(candidates)
		}
	}

	// drain the heap, then order the survivors best-first
	result := make([]EmbeddingSimilarity, 0, candidates.Len())
	for candidates.Len() > 0 {
		result = append(result, heap.Pop(candidates).(EmbeddingSimilarity))
	}
	sort.Slice(result, func(i, j int) bool {
		return result[i].Similarity > result[j].Similarity
	})
	return result
}