Merge branch 'main' into add_oterm

This commit is contained in:
Michael Yang 2023-10-16 15:51:44 -07:00 committed by GitHub
commit 785b4eb5bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 80 additions and 354 deletions

View file

@ -231,3 +231,4 @@ curl -X POST http://localhost:11434/api/generate -d '{
- [Dumbar](https://github.com/JerrySievert/Dumbar) - [Dumbar](https://github.com/JerrySievert/Dumbar)
- [Emacs client](https://github.com/zweifisch/ollama) - [Emacs client](https://github.com/zweifisch/ollama)
- [oterm](https://github.com/ggozad/oterm) - [oterm](https://github.com/ggozad/oterm)
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)

View file

@ -14,6 +14,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
@ -127,7 +128,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil return nil
} }
const maxBufferSize = 512 * 1000 // 512KB const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer var buf *bytes.Buffer

View file

@ -78,18 +78,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop() spinner.Stop()
} }
currentDigest = resp.Digest currentDigest = resp.Digest
switch { // pulling
case strings.Contains(resp.Status, "embeddings"): bar = progressbar.DefaultBytes(
bar = progressbar.Default(resp.Total, resp.Status) resp.Total,
bar.Set64(resp.Completed) resp.Status,
default: )
// pulling bar.Set64(resp.Completed)
bar = progressbar.DefaultBytes(
resp.Total,
resp.Status,
)
bar.Set64(resp.Completed)
}
} else if resp.Digest == currentDigest && resp.Digest != "" { } else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set64(resp.Completed) bar.Set64(resp.Completed)
} else { } else {
@ -694,7 +688,12 @@ func generateInteractive(cmd *cobra.Command, model string) error {
case strings.HasPrefix(line, "/show"): case strings.HasPrefix(line, "/show"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
resp, err := server.GetModelInfo(model) client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
if err != nil { if err != nil {
fmt.Println("error: couldn't get model") fmt.Println("error: couldn't get model")
return err return err

View file

@ -12,7 +12,6 @@ A model file is the blueprint to create and share models with Ollama.
- [FROM (Required)](#from-required) - [FROM (Required)](#from-required)
- [Build from llama2](#build-from-llama2) - [Build from llama2](#build-from-llama2)
- [Build from a bin file](#build-from-a-bin-file) - [Build from a bin file](#build-from-a-bin-file)
- [EMBED](#embed)
- [PARAMETER](#parameter) - [PARAMETER](#parameter)
- [Valid Parameters and Values](#valid-parameters-and-values) - [Valid Parameters and Values](#valid-parameters-and-values)
- [TEMPLATE](#template) - [TEMPLATE](#template)
@ -91,17 +90,6 @@ FROM ./ollama-model.bin
This bin file location should be specified as an absolute path or relative to the `Modelfile` location. This bin file location should be specified as an absolute path or relative to the `Modelfile` location.
### EMBED
The `EMBED` instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer. Note that currently only text files are supported, formatted with each line as one embedding.
```modelfile
FROM <model name>:<tag>
EMBED <file path>.txt
EMBED <different file path>.txt
EMBED <path to directory>/*.txt
```
### PARAMETER ### PARAMETER
The `PARAMETER` instruction defines a parameter that can be set when the model is run. The `PARAMETER` instruction defines a parameter that can be set when the model is run.

View file

@ -2,14 +2,21 @@ package format
import "fmt" import "fmt"
const (
Byte = 1
KiloByte = Byte * 1000
MegaByte = KiloByte * 1000
GigaByte = MegaByte * 1000
)
func HumanBytes(b int64) string { func HumanBytes(b int64) string {
switch { switch {
case b > 1000*1000*1000: case b > GigaByte:
return fmt.Sprintf("%d GB", b/1000/1000/1000) return fmt.Sprintf("%d GB", b/GigaByte)
case b > 1000*1000: case b > MegaByte:
return fmt.Sprintf("%d MB", b/1000/1000) return fmt.Sprintf("%d MB", b/MegaByte)
case b > 1000: case b > KiloByte:
return fmt.Sprintf("%d KB", b/1000) return fmt.Sprintf("%d KB", b/KiloByte)
default: default:
return fmt.Sprintf("%d B", b) return fmt.Sprintf("%d B", b)
} }

1
go.mod
View file

@ -45,7 +45,6 @@ require (
golang.org/x/sys v0.11.0 // indirect golang.org/x/sys v0.11.0 // indirect
golang.org/x/term v0.10.0 golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect
gonum.org/v1/gonum v0.13.0
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

2
go.sum
View file

@ -145,8 +145,6 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View file

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
) )
//go:embed llama.cpp/*/build/*/bin/* //go:embed llama.cpp/*/build/*/bin/*
@ -197,7 +198,7 @@ type llama struct {
var errNoGPU = errors.New("nvidia-smi command failed") var errNoGPU = errors.New("nvidia-smi command failed")
// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs // CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) { func CheckVRAM() (int64, error) {
cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits") cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
var stdout bytes.Buffer var stdout bytes.Buffer
@ -207,7 +208,7 @@ func CheckVRAM() (int64, error) {
return 0, errNoGPU return 0, errNoGPU
} }
var free int64 var freeMiB int64
scanner := bufio.NewScanner(&stdout) scanner := bufio.NewScanner(&stdout)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -216,15 +217,16 @@ func CheckVRAM() (int64, error) {
return 0, fmt.Errorf("failed to parse available VRAM: %v", err) return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
} }
free += vram freeMiB += vram
} }
if free*1024*1024 < 2*1000*1000*1000 { freeBytes := freeMiB * 1024 * 1024
if freeBytes < 2*format.GigaByte {
log.Printf("less than 2 GB VRAM available, falling back to CPU only") log.Printf("less than 2 GB VRAM available, falling back to CPU only")
free = 0 freeMiB = 0
} }
return free, nil return freeBytes, nil
} }
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
@ -232,7 +234,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
return opts.NumGPU return opts.NumGPU
} }
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
vramMib, err := CheckVRAM() freeBytes, err := CheckVRAM()
if err != nil { if err != nil {
if err.Error() != "nvidia-smi command failed" { if err.Error() != "nvidia-smi command failed" {
log.Print(err.Error()) log.Print(err.Error())
@ -241,15 +243,13 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
return 0 return 0
} }
freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes
// Calculate bytes per layer // Calculate bytes per layer
// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size // TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
bytesPerLayer := fileSizeBytes / numLayer bytesPerLayer := fileSizeBytes / numLayer
// max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory // max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory
layers := int(freeVramBytes/bytesPerLayer) * 92 / 100 layers := int(freeBytes/bytesPerLayer) * 92 / 100
log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers) log.Printf("%d MiB VRAM available, loading up to %d GPU layers", freeBytes, layers)
return layers return layers
} }
@ -509,7 +509,7 @@ type PredictRequest struct {
Stop []string `json:"stop,omitempty"` Stop []string `json:"stop,omitempty"`
} }
const maxBufferSize = 512 * 1000 // 512KB const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext) prevConvo, err := llm.Decode(ctx, prevContext)

View file

@ -10,6 +10,7 @@ import (
"github.com/pbnjay/memory" "github.com/pbnjay/memory"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
) )
type LLM interface { type LLM interface {
@ -55,39 +56,30 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
opts.NumGPU = 0 opts.NumGPU = 0
} }
} }
}
totalResidentMemory := memory.TotalMemory() var requiredMemory int64
switch ggml.ModelType() { var f16Multiplier int64 = 2
case "3B", "7B":
if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 { switch ggml.ModelType() {
return nil, fmt.Errorf("F16 model requires at least 16 GB of memory") case "3B", "7B":
} else if totalResidentMemory < 8*1000*1000 { requiredMemory = 8 * format.GigaByte
return nil, fmt.Errorf("model requires at least 8 GB of memory") case "13B":
requiredMemory = 16 * format.GigaByte
case "30B", "34B", "40B":
requiredMemory = 32 * format.GigaByte
case "65B", "70B":
requiredMemory = 64 * format.GigaByte
case "180B":
requiredMemory = 128 * format.GigaByte
f16Multiplier = 4
} }
case "13B":
if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 { systemMemory := int64(memory.TotalMemory())
return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
} else if totalResidentMemory < 16*1000*1000 { if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > systemMemory {
return nil, fmt.Errorf("model requires at least 16 GB of memory") return nil, fmt.Errorf("F16 model requires at least %s of total memory", format.HumanBytes(requiredMemory))
} } else if requiredMemory > systemMemory {
case "30B", "34B", "40B": return nil, fmt.Errorf("model requires at least %s of total memory", format.HumanBytes(requiredMemory))
if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 {
return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
} else if totalResidentMemory < 32*1000*1000 {
return nil, fmt.Errorf("model requires at least 32 GB of memory")
}
case "65B", "70B":
if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 {
return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
} else if totalResidentMemory < 64*1000*1000 {
return nil, fmt.Errorf("model requires at least 64 GB of memory")
}
case "180B":
if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 {
return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
} else if totalResidentMemory < 128*1000*1000 {
return nil, fmt.Errorf("model requires at least 128GB of memory")
} }
} }

View file

@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(fields[1]) command.Args = string(fields[1])
// copy command for validation // copy command for validation
modelCommand = command modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED", "ADAPTER": case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
command.Name = string(bytes.ToLower(fields[0])) command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1]) command.Args = string(fields[1])
case "PARAMETER": case "PARAMETER":
@ -51,6 +51,8 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Name = string(fields[0]) command.Name = string(fields[0])
command.Args = string(fields[1]) command.Args = string(fields[1])
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands // log a warning for unknown commands

View file

@ -26,7 +26,8 @@ require() {
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.' [ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
case "$(uname -m)" in ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;; x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;; aarch64|arm64) ARCH="arm64" ;;
*) error "Unsupported architecture: $ARCH" ;; *) error "Unsupported architecture: $ARCH" ;;

View file

@ -1,7 +1,6 @@
package server package server
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
@ -26,7 +25,6 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
@ -49,10 +47,9 @@ type Model struct {
Digest string Digest string
ConfigDigest string ConfigDigest string
Options map[string]interface{} Options map[string]interface{}
Embeddings []vector.Embedding
} }
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) { func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
t := m.Template t := m.Template
if request.Template != "" { if request.Template != "" {
t = request.Template t = request.Template
@ -67,7 +64,6 @@ func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, e
First bool First bool
System string System string
Prompt string Prompt string
Embed string
// deprecated: versions <= 0.0.7 used this to omit the system prompt // deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int Context []int
@ -77,7 +73,6 @@ func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, e
vars.System = m.System vars.System = m.System
vars.Prompt = request.Prompt vars.Prompt = request.Prompt
vars.Context = request.Context vars.Context = request.Context
vars.Embed = embedding
if request.System != "" { if request.System != "" {
vars.System = request.System vars.System = request.System
@ -190,15 +185,9 @@ func GetModel(name string) (*Model, error) {
model.ModelPath = filename model.ModelPath = filename
model.OriginalModel = layer.From model.OriginalModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
file, err := os.Open(filename) // Deprecated in versions > 0.1.2
if err != nil { // TODO: remove this warning in a future version
return nil, fmt.Errorf("failed to open file: %s", filename) log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
}
defer file.Close()
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
return nil, err
}
case "application/vnd.ollama.image.adapter": case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename) model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.template":
@ -310,13 +299,11 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
var sourceParams map[string]any var sourceParams map[string]any
embed := EmbeddingParams{fn: fn}
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args
mp := ParseModelPath(c.Args) mp := ParseModelPath(c.Args)
mf, _, err := GetManifest(mp) mf, _, err := GetManifest(mp)
@ -340,7 +327,6 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
return err return err
} }
} else { } else {
embed.model = modelFile
// create a model from this specified file // create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"}) fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile) file, err := os.Open(modelFile)
@ -421,12 +407,6 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
layers = append(layers, newLayer) layers = append(layers, newLayer)
} }
} }
case "embed":
embedFilePath, err := filenameWithPath(path, c.Args)
if err != nil {
return err
}
embed.files = append(embed.files, embedFilePath)
case "adapter": case "adapter":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
@ -517,18 +497,8 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
} }
l.MediaType = "application/vnd.ollama.image.params" l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l) layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = formattedParams
} }
// generate the embedding layers
embeddingLayers, err := embeddingLayers(workDir, embed)
if err != nil {
return err
}
layers = append(layers, embeddingLayers...)
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
if err != nil { if err != nil {
return err return err
@ -572,146 +542,6 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
return nil return nil
} }
type EmbeddingParams struct {
model string
opts map[string]interface{}
files []string // paths to files to embed
fn func(resp api.ProgressResponse)
}
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{}
if len(e.files) > 0 {
// check if the model is a file path or a model name
model, err := GetModel(e.model)
if err != nil {
if !strings.Contains(err.Error(), "couldn't open file") {
return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
}
// the model may be a file path, create a model from this file
model = &Model{ModelPath: e.model}
}
if err := load(context.Background(), workDir, model, e.opts, defaultSessionDuration); err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
// this will be used to check if we already have embeddings for a file
modelInfo, err := os.Stat(model.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to get model file info: %v", err)
}
addedFiles := make(map[string]bool) // keep track of files that have already been added
for _, filePattern := range e.files {
matchingFiles, err := filepath.Glob(filePattern)
if err != nil {
return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
}
for _, filePath := range matchingFiles {
if addedFiles[filePath] {
continue
}
addedFiles[filePath] = true
// check if we already have embeddings for this file path
layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
existing, err := existingFileEmbeddings(digest)
if err != nil {
return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
}
// TODO: check file type
f, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("could not open embed file: %w", err)
}
scanner := bufio.NewScanner(f)
scanner.Split(bufio.ScanLines)
data := []string{}
for scanner.Scan() {
data = append(data, scanner.Text())
}
f.Close()
// the digest of the file is set here so that the client knows a new operation is in progress
fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
embeddings := []vector.Embedding{}
for i, d := range data {
if strings.TrimSpace(d) == "" {
continue
}
e.fn(api.ProgressResponse{
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
Digest: fileDigest,
Total: int64(len(data) - 1),
Completed: int64(i),
})
if len(existing[d]) > 0 {
// already have an embedding for this line
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
continue
}
embed, err := loaded.llm.Embedding(context.Background(), d)
if err != nil {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue
}
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
}
b, err := json.Marshal(embeddings)
if err != nil {
return nil, fmt.Errorf("failed to encode embeddings: %w", err)
}
r := bytes.NewReader(b)
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.ollama.image.embed",
Digest: digest,
Size: r.Size(),
},
Reader: r,
}
layers = append(layers, layer)
}
}
}
return layers, nil
}
// existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
func existingFileEmbeddings(digest string) (map[string][]float64, error) {
path, err := GetBlobsPath(digest)
if err != nil {
return nil, fmt.Errorf("embeddings blobs path: %w", err)
}
existingFileEmbeddings := make(map[string][]float64)
if _, err := os.Stat(path); err == nil {
// already have some embeddings for this file, load embeddings previously generated
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
}
defer file.Close()
existing := []vector.Embedding{}
if err = json.NewDecoder(file).Decode(&existing); err != nil {
return nil, err
}
for _, e := range existing {
existingFileEmbeddings[e.Data] = e.Vector
}
}
return existingFileEmbeddings, nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader { func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
return slices.DeleteFunc(layers, func(layer *LayerReader) bool { return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
return layer.MediaType == mediaType return layer.MediaType == mediaType
@ -727,8 +557,7 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
} }
_, err = os.Stat(fp) _, err = os.Stat(fp)
// note: embed layers are always written since their digest doesnt indicate anything about the contents if os.IsNotExist(err) || force {
if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)}) fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
out, err := os.Create(fp) out, err := os.Create(fp)

View file

@ -12,7 +12,7 @@ func TestModelPrompt(t *testing.T) {
Template: "a{{ .Prompt }}b", Template: "a{{ .Prompt }}b",
Prompt: "<h1>", Prompt: "<h1>",
} }
s, err := m.Prompt(req, "") s, err := m.Prompt(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -23,11 +23,10 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/vector" "github.com/jmorganca/ollama/version"
) )
var mode string = gin.DebugMode var mode string = gin.DebugMode
@ -47,8 +46,7 @@ func init() {
var loaded struct { var loaded struct {
mu sync.Mutex mu sync.Mutex
llm llm.LLM llm llm.LLM
Embeddings []vector.Embedding
expireAt time.Time expireAt time.Time
expireTimer *time.Timer expireTimer *time.Timer
@ -90,11 +88,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.digest = "" loaded.digest = ""
} }
if model.Embeddings != nil && len(model.Embeddings) > 0 {
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
loaded.Embeddings = model.Embeddings
}
llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts) llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
if err != nil { if err != nil {
return err return err
@ -106,12 +99,12 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.options = opts loaded.options = opts
if opts.NumKeep < 0 { if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "") promptWithSystem, err := model.Prompt(api.GenerateRequest{})
if err != nil { if err != nil {
return err return err
} }
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "") promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
if err != nil { if err != nil {
return err return err
} }
@ -195,22 +188,7 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
embedding := "" prompt, err := model.Prompt(req)
if model.Embeddings != nil && len(model.Embeddings) > 0 {
promptEmbed, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO: set embed_top from specified parameters in modelfile
embed_top := 3
topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
for _, e := range topK {
embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
}
}
prompt, err := model.Prompt(req, embedding)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -611,7 +589,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.Handle(method, "/api/tags", ListModelsHandler) r.Handle(method, "/api/tags", ListModelsHandler)
} }
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
s := &http.Server{ s := &http.Server{
Handler: r, Handler: r,
} }

View file

@ -1,69 +0,0 @@
package vector
import (
"container/heap"
"sort"
"gonum.org/v1/gonum/mat"
)
type Embedding struct {
Vector []float64 // the embedding vector
Data string // the data represted by the embedding
}
type EmbeddingSimilarity struct {
Embedding Embedding // the embedding that was used to calculate the similarity
Similarity float64 // the similarity between the embedding and the query
}
type Heap []EmbeddingSimilarity
func (h Heap) Len() int { return len(h) }
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *Heap) Push(e any) {
*h = append(*h, e.(EmbeddingSimilarity))
}
func (h *Heap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
// This value will range from -1 to 1, where 1 means the vectors are identical.
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
dotProduct := mat.Dot(vec1, vec2)
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
if norms == 0 {
return 0
}
return dotProduct / norms
}
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
h := &Heap{}
heap.Init(h)
for _, emb := range embeddings {
similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
if h.Len() > k {
heap.Pop(h)
}
}
topK := make([]EmbeddingSimilarity, 0, h.Len())
for h.Len() > 0 {
topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
}
sort.Slice(topK, func(i, j int) bool {
return topK[i].Similarity > topK[j].Similarity
})
return topK
}