embed text document in modelfile
This commit is contained in:
parent
34a13a9d05
commit
a6f6d18f83
8 changed files with 330 additions and 59 deletions
18
cmd/cmd.go
18
cmd/cmd.go
|
@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
}
|
}
|
||||||
currentDigest = resp.Digest
|
currentDigest = resp.Digest
|
||||||
bar = progressbar.DefaultBytes(
|
switch {
|
||||||
int64(resp.Total),
|
case strings.Contains(resp.Status, "embeddings"):
|
||||||
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
|
bar = progressbar.Default(int64(resp.Total), resp.Status)
|
||||||
)
|
bar.Set(resp.Completed)
|
||||||
|
default:
|
||||||
bar.Set(resp.Completed)
|
// pulling
|
||||||
|
bar = progressbar.DefaultBytes(
|
||||||
|
int64(resp.Total),
|
||||||
|
resp.Status,
|
||||||
|
)
|
||||||
|
bar.Set(resp.Completed)
|
||||||
|
}
|
||||||
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
||||||
bar.Set(resp.Completed)
|
bar.Set(resp.Completed)
|
||||||
} else {
|
} else {
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -42,6 +42,7 @@ require (
|
||||||
golang.org/x/sys v0.10.0 // indirect
|
golang.org/x/sys v0.10.0 // indirect
|
||||||
golang.org/x/term v0.10.0
|
golang.org/x/term v0.10.0
|
||||||
golang.org/x/text v0.10.0 // indirect
|
golang.org/x/text v0.10.0 // indirect
|
||||||
|
gonum.org/v1/gonum v0.13.0
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
|
||||||
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
|
||||||
|
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||||
|
|
|
@ -85,6 +85,7 @@ llama_token llama_sample(
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"embed"
|
"embed"
|
||||||
|
@ -93,6 +94,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -414,3 +416,38 @@ func (llm *LLM) next() (C.llama_token, error) {
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *LLM) Embedding(input string) ([]float64, error) {
|
||||||
|
if !llm.EmbeddingOnly {
|
||||||
|
return nil, errors.New("llama: embedding not enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := llm.tokenize(input)
|
||||||
|
if tokens == nil {
|
||||||
|
return nil, errors.New("llama: tokenize embedding")
|
||||||
|
}
|
||||||
|
|
||||||
|
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
|
||||||
|
if retval != 0 {
|
||||||
|
return nil, errors.New("llama: eval")
|
||||||
|
}
|
||||||
|
|
||||||
|
n := int(C.llama_n_embd(llm.ctx))
|
||||||
|
if n <= 0 {
|
||||||
|
return nil, errors.New("llama: no embeddings generated")
|
||||||
|
}
|
||||||
|
|
||||||
|
embedPtr := C.llama_get_embeddings(llm.ctx)
|
||||||
|
if embedPtr == nil {
|
||||||
|
return nil, errors.New("llama: embedding retrieval failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
header := reflect.SliceHeader{
|
||||||
|
Data: uintptr(unsafe.Pointer(embedPtr)),
|
||||||
|
Len: n,
|
||||||
|
Cap: n,
|
||||||
|
}
|
||||||
|
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
|
||||||
|
|
||||||
|
return embedSlice, nil
|
||||||
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||||
command.Args = string(fields[1])
|
command.Args = string(fields[1])
|
||||||
// copy command for validation
|
// copy command for validation
|
||||||
modelCommand = command
|
modelCommand = command
|
||||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
command.Name = string(bytes.ToLower(fields[0]))
|
||||||
command.Args = string(fields[1])
|
command.Args = string(fields[1])
|
||||||
case "PARAMETER":
|
case "PARAMETER":
|
||||||
|
|
250
server/images.go
250
server/images.go
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
@ -18,7 +20,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
|
"github.com/jmorganca/ollama/llama"
|
||||||
"github.com/jmorganca/ollama/parser"
|
"github.com/jmorganca/ollama/parser"
|
||||||
|
"github.com/jmorganca/ollama/vector"
|
||||||
|
"gonum.org/v1/gonum/mat"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RegistryOptions struct {
|
type RegistryOptions struct {
|
||||||
|
@ -28,12 +33,13 @@ type RegistryOptions struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ModelPath string
|
ModelPath string
|
||||||
Template string
|
Template string
|
||||||
System string
|
System string
|
||||||
Digest string
|
Digest string
|
||||||
Options map[string]interface{}
|
Options map[string]interface{}
|
||||||
|
Embeddings []vector.Embedding
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||||
|
@ -51,6 +57,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||||
First bool
|
First bool
|
||||||
System string
|
System string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
Embed string
|
||||||
|
|
||||||
// deprecated: versions <= 0.0.7 used this to omit the system prompt
|
// deprecated: versions <= 0.0.7 used this to omit the system prompt
|
||||||
Context []int
|
Context []int
|
||||||
|
@ -65,6 +72,21 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||||
vars.System = request.System
|
vars.System = request.System
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(m.Embeddings) > 0 {
|
||||||
|
promptEmbed, err := loaded.llm.Embedding(request.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get embedding for prompt: %v", err)
|
||||||
|
}
|
||||||
|
// TODO: set embed_top from specified parameters in modelfile
|
||||||
|
embed_top := 3
|
||||||
|
embed := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
|
||||||
|
toEmbed := ""
|
||||||
|
for _, e := range embed {
|
||||||
|
toEmbed = fmt.Sprintf("%s %s", toEmbed, e.Embedding.Data)
|
||||||
|
}
|
||||||
|
vars.Embed = toEmbed
|
||||||
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
if err := tmpl.Execute(&sb, vars); err != nil {
|
if err := tmpl.Execute(&sb, vars); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -157,6 +179,16 @@ func GetModel(name string) (*Model, error) {
|
||||||
switch layer.MediaType {
|
switch layer.MediaType {
|
||||||
case "application/vnd.ollama.image.model":
|
case "application/vnd.ollama.image.model":
|
||||||
model.ModelPath = filename
|
model.ModelPath = filename
|
||||||
|
case "application/vnd.ollama.image.embed":
|
||||||
|
file, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open file: %s", filename)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
case "application/vnd.ollama.image.template":
|
case "application/vnd.ollama.image.template":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -195,6 +227,26 @@ func GetModel(name string) (*Model, error) {
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func filenameWithPath(path, f string) (string, error) {
|
||||||
|
// if filePath starts with ~/, replace it with the user's home directory.
|
||||||
|
if strings.HasPrefix(f, "~/") {
|
||||||
|
parts := strings.Split(f, "/")
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to open file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f = filepath.Join(home, filepath.Join(parts[1:]...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// if filePath is not an absolute path, make it relative to the modelfile path
|
||||||
|
if !filepath.IsAbs(f) {
|
||||||
|
f = filepath.Join(filepath.Dir(path), f)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
|
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
|
||||||
mf, err := os.Open(path)
|
mf, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -211,52 +263,37 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
||||||
|
|
||||||
var layers []*LayerReader
|
var layers []*LayerReader
|
||||||
params := make(map[string][]string)
|
params := make(map[string][]string)
|
||||||
|
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
|
||||||
for _, c := range commands {
|
for _, c := range commands {
|
||||||
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
||||||
switch c.Name {
|
switch c.Name {
|
||||||
case "model":
|
case "model":
|
||||||
fn(api.ProgressResponse{Status: "looking for model"})
|
fn(api.ProgressResponse{Status: "looking for model"})
|
||||||
|
embed.model = c.Args
|
||||||
mf, err := GetManifest(ParseModelPath(c.Args))
|
mf, err := GetManifest(ParseModelPath(c.Args))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fp := c.Args
|
modelFile, err := filenameWithPath(path, c.Args)
|
||||||
|
if err != nil {
|
||||||
// If filePath starts with ~/, replace it with the user's home directory.
|
return err
|
||||||
if strings.HasPrefix(fp, "~/") {
|
|
||||||
parts := strings.Split(fp, "/")
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fp = filepath.Join(home, filepath.Join(parts[1:]...))
|
|
||||||
}
|
}
|
||||||
|
if _, err := os.Stat(modelFile); err != nil {
|
||||||
// If filePath is not an absolute path, make it relative to the modelfile path
|
|
||||||
if !filepath.IsAbs(fp) {
|
|
||||||
fp = filepath.Join(filepath.Dir(path), fp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(fp); err != nil {
|
|
||||||
// the model file does not exist, try pulling it
|
// the model file does not exist, try pulling it
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
fn(api.ProgressResponse{Status: "pulling model file"})
|
fn(api.ProgressResponse{Status: "pulling model file"})
|
||||||
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
|
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mf, err = GetManifest(ParseModelPath(c.Args))
|
mf, err = GetManifest(ParseModelPath(modelFile))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open file after pull: %v", err)
|
return fmt.Errorf("failed to open file after pull: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// create a model from this specified file
|
// create a model from this specified file
|
||||||
fn(api.ProgressResponse{Status: "creating model layer"})
|
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||||
|
file, err := os.Open(modelFile)
|
||||||
file, err := os.Open(fp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open file: %v", err)
|
return fmt.Errorf("failed to open file: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -280,19 +317,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
||||||
layers = append(layers, newLayer)
|
layers = append(layers, newLayer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "license":
|
case "embed":
|
||||||
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
// TODO: support entire directories here
|
||||||
// remove the prompt layer if one exists
|
embedFilePath, err := filenameWithPath(path, c.Args)
|
||||||
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
|
||||||
|
|
||||||
layer, err := CreateLayer(strings.NewReader(c.Args))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
embed.files = append(embed.files, embedFilePath)
|
||||||
layer.MediaType = mediaType
|
case "license", "template", "system", "prompt":
|
||||||
layers = append(layers, layer)
|
|
||||||
case "template", "system", "prompt":
|
|
||||||
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
||||||
// remove the prompt layer if one exists
|
// remove the prompt layer if one exists
|
||||||
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||||
|
@ -315,18 +347,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
fn(api.ProgressResponse{Status: "creating parameter layer"})
|
fn(api.ProgressResponse{Status: "creating parameter layer"})
|
||||||
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
|
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
|
||||||
paramData, err := paramsToReader(params)
|
formattedParams, err := formatParams(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't create params json: %v", err)
|
return fmt.Errorf("couldn't create params json: %v", err)
|
||||||
}
|
}
|
||||||
l, err := CreateLayer(paramData)
|
|
||||||
|
bts, err := json.Marshal(formattedParams)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := CreateLayer(bytes.NewReader(bts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create layer: %v", err)
|
return fmt.Errorf("failed to create layer: %v", err)
|
||||||
}
|
}
|
||||||
l.MediaType = "application/vnd.ollama.image.params"
|
l.MediaType = "application/vnd.ollama.image.params"
|
||||||
layers = append(layers, l)
|
layers = append(layers, l)
|
||||||
|
|
||||||
|
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
|
||||||
|
embed.opts = api.DefaultOptions()
|
||||||
|
embed.opts.FromMap(formattedParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate the embedding layers
|
||||||
|
embeddingLayers, err := embeddingLayers(embed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
layers = append(layers, embeddingLayers...)
|
||||||
|
|
||||||
digests, err := getLayerDigests(layers)
|
digests, err := getLayerDigests(layers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -361,6 +410,112 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddingParams struct {
|
||||||
|
model string
|
||||||
|
opts api.Options
|
||||||
|
files []string // paths to files to embed
|
||||||
|
fn func(resp api.ProgressResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
|
||||||
|
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
||||||
|
layers := []*LayerReader{}
|
||||||
|
if len(e.files) > 0 {
|
||||||
|
model, err := GetModel(e.model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.opts.EmbeddingOnly = true
|
||||||
|
llm, err := llama.New(model.ModelPath, e.opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, filePath := range e.files {
|
||||||
|
// TODO: check if txt file type
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not open embed file: %w", err)
|
||||||
|
}
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
|
data := []string{}
|
||||||
|
for scanner.Scan() {
|
||||||
|
data = append(data, scanner.Text())
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
// the digest of the file is set here so that the client knows a new operation is in progress
|
||||||
|
fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
|
||||||
|
|
||||||
|
embeddings := []vector.Embedding{}
|
||||||
|
for i, d := range data {
|
||||||
|
if strings.TrimSpace(d) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
e.fn(api.ProgressResponse{
|
||||||
|
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
|
||||||
|
Digest: fileDigest,
|
||||||
|
Total: len(data) - 1,
|
||||||
|
Completed: i,
|
||||||
|
})
|
||||||
|
retry := 0
|
||||||
|
generate:
|
||||||
|
if retry > 3 {
|
||||||
|
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
embed, err := llm.Embedding(d)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
|
||||||
|
retry++
|
||||||
|
goto generate
|
||||||
|
}
|
||||||
|
// Check for NaN and Inf in the embedding, which can't be stored
|
||||||
|
for _, value := range embed {
|
||||||
|
if math.IsNaN(value) || math.IsInf(value, 0) {
|
||||||
|
log.Printf("reloading model, embedding contains NaN or Inf")
|
||||||
|
// reload the model to get a new embedding
|
||||||
|
llm, err = llama.New(model.ModelPath, e.opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||||
|
}
|
||||||
|
retry++
|
||||||
|
goto generate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(embeddings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode embeddings: %w", err)
|
||||||
|
}
|
||||||
|
r := bytes.NewReader(b)
|
||||||
|
|
||||||
|
digest, size := GetSHA256Digest(r)
|
||||||
|
// Reset the position of the reader after calculating the digest
|
||||||
|
if _, err := r.Seek(0, 0); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not reset embed reader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer := &LayerReader{
|
||||||
|
Layer: Layer{
|
||||||
|
MediaType: "application/vnd.ollama.image.embed",
|
||||||
|
Digest: digest,
|
||||||
|
Size: size,
|
||||||
|
},
|
||||||
|
Reader: r,
|
||||||
|
}
|
||||||
|
|
||||||
|
layers = append(layers, layer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return layers, nil
|
||||||
|
}
|
||||||
|
|
||||||
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
|
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
|
||||||
j := 0
|
j := 0
|
||||||
for _, l := range layers {
|
for _, l := range layers {
|
||||||
|
@ -449,8 +604,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
|
||||||
return newLayer, nil
|
return newLayer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
|
// formatParams converts specified parameter options to their correct types
|
||||||
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
func formatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||||
opts := api.Options{}
|
opts := api.Options{}
|
||||||
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||||
|
@ -504,12 +659,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bts, err := json.Marshal(out)
|
return out, nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes.NewReader(bts), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getLayerDigests(layers []*LayerReader) ([]string, error) {
|
func getLayerDigests(layers []*LayerReader) ([]string, error) {
|
||||||
|
@ -1042,7 +1192,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
|
||||||
|
|
||||||
for {
|
for {
|
||||||
fn(api.ProgressResponse{
|
fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("downloading %s", digest),
|
Status: fmt.Sprintf("pulling %s...", digest[7:19]),
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Total: int(total),
|
Total: int(total),
|
||||||
Completed: int(completed),
|
Completed: int(completed),
|
||||||
|
|
|
@ -20,12 +20,14 @@ import (
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llama"
|
||||||
|
"github.com/jmorganca/ollama/vector"
|
||||||
)
|
)
|
||||||
|
|
||||||
var loaded struct {
|
var loaded struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
llm *llama.LLM
|
llm *llama.LLM
|
||||||
|
Embeddings []vector.Embedding
|
||||||
|
|
||||||
expireAt time.Time
|
expireAt time.Time
|
||||||
expireTimer *time.Timer
|
expireTimer *time.Timer
|
||||||
|
@ -72,6 +74,11 @@ func GenerateHandler(c *gin.Context) {
|
||||||
loaded.digest = ""
|
loaded.digest = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if model.Embeddings != nil && len(model.Embeddings) > 0 {
|
||||||
|
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
|
||||||
|
loaded.Embeddings = model.Embeddings
|
||||||
|
}
|
||||||
|
|
||||||
llm, err := llama.New(model.ModelPath, opts)
|
llm, err := llama.New(model.ModelPath, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
@ -82,7 +89,6 @@ func GenerateHandler(c *gin.Context) {
|
||||||
loaded.digest = model.Digest
|
loaded.digest = model.Digest
|
||||||
loaded.options = opts
|
loaded.options = opts
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDuration := 5 * time.Minute
|
sessionDuration := 5 * time.Minute
|
||||||
|
|
||||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||||
|
|
69
vector/store.go
Normal file
69
vector/store.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/heap"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/mat"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Embedding struct {
|
||||||
|
Vector []float64 // the embedding vector
|
||||||
|
Data string // the data represted by the embedding
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingSimilarity struct {
|
||||||
|
Embedding Embedding // the embedding that was used to calculate the similarity
|
||||||
|
Similarity float64 // the similarity between the embedding and the query
|
||||||
|
}
|
||||||
|
|
||||||
|
type Heap []EmbeddingSimilarity
|
||||||
|
|
||||||
|
func (h Heap) Len() int { return len(h) }
|
||||||
|
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
|
||||||
|
func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||||
|
func (h *Heap) Push(e any) {
|
||||||
|
*h = append(*h, e.(EmbeddingSimilarity))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Heap) Pop() interface{} {
|
||||||
|
old := *h
|
||||||
|
n := len(old)
|
||||||
|
x := old[n-1]
|
||||||
|
*h = old[0 : n-1]
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
|
||||||
|
// This value will range from -1 to 1, where 1 means the vectors are identical.
|
||||||
|
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
|
||||||
|
dotProduct := mat.Dot(vec1, vec2)
|
||||||
|
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
|
||||||
|
|
||||||
|
if norms == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return dotProduct / norms
|
||||||
|
}
|
||||||
|
|
||||||
|
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
|
||||||
|
h := &Heap{}
|
||||||
|
heap.Init(h)
|
||||||
|
for _, emb := range embeddings {
|
||||||
|
similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
|
||||||
|
heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
|
||||||
|
if h.Len() > k {
|
||||||
|
heap.Pop(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
topK := make([]EmbeddingSimilarity, 0, h.Len())
|
||||||
|
for h.Len() > 0 {
|
||||||
|
topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
|
||||||
|
}
|
||||||
|
sort.Slice(topK, func(i, j int) bool {
|
||||||
|
return topK[i].Similarity > topK[j].Similarity
|
||||||
|
})
|
||||||
|
|
||||||
|
return topK
|
||||||
|
}
|
Loading…
Reference in a new issue