embed text document in modelfile

This commit is contained in:
Bruce MacDonald 2023-08-09 10:26:19 -04:00 committed by GitHub
commit 7a5f3616fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 371 additions and 52 deletions

View file

@ -276,6 +276,7 @@ func DefaultOptions() Options {
UseMLock: false,
RopeFrequencyBase: 10000.0,
RopeFrequencyScale: 1.0,
EmbeddingOnly: true,
RepeatLastN: 64,
RepeatPenalty: 1.1,

View file

@ -48,12 +48,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop()
}
currentDigest = resp.Digest
switch {
case strings.Contains(resp.Status, "embeddings"):
bar = progressbar.Default(int64(resp.Total), resp.Status)
bar.Set(resp.Completed)
default:
// pulling
bar = progressbar.DefaultBytes(
int64(resp.Total),
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
resp.Status,
)
bar.Set(resp.Completed)
}
} else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed)
} else {

View file

@ -12,6 +12,7 @@ A model file is the blueprint to create and share models with Ollama.
- [FROM (Required)](#from-required)
- [Build from llama2](#build-from-llama2)
- [Build from a bin file](#build-from-a-bin-file)
- [EMBED](#embed)
- [PARAMETER](#parameter)
- [Valid Parameters and Values](#valid-parameters-and-values)
- [TEMPLATE](#template)
@ -88,6 +89,15 @@ FROM ./ollama-model.bin
This bin file location should be specified as an absolute path or relative to the Modelfile location.
### EMBED
The EMBED instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer.
```
FROM <model name>:<tag>
EMBED <file path>
```
### PARAMETER
The `PARAMETER` instruction defines a parameter that can be set when the model is run.

1
go.mod
View file

@ -42,6 +42,7 @@ require (
golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect
gonum.org/v1/gonum v0.13.0
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

2
go.sum
View file

@ -139,6 +139,8 @@ golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View file

@ -85,6 +85,7 @@ llama_token llama_sample(
}
*/
import "C"
import (
"bytes"
"embed"
@ -93,6 +94,7 @@ import (
"io"
"log"
"os"
"reflect"
"strings"
"sync"
"unicode/utf8"
@ -408,3 +410,38 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil
}
func (llm *LLM) Embedding(input string) ([]float64, error) {
if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled")
}
tokens := llm.tokenize(input)
if tokens == nil {
return nil, errors.New("llama: tokenize embedding")
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
if retval != 0 {
return nil, errors.New("llama: eval")
}
n := int(C.llama_n_embd(llm.ctx))
if n <= 0 {
return nil, errors.New("llama: no embeddings generated")
}
embedPtr := C.llama_get_embeddings(llm.ctx)
if embedPtr == nil {
return nil, errors.New("llama: embedding retrieval failed")
}
header := reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(embedPtr)),
Len: n,
Cap: n,
}
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
return embedSlice, nil
}

View file

@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(fields[1])
// copy command for validation
modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":

View file

@ -1,6 +1,7 @@
package server
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/json"
@ -9,6 +10,7 @@ import (
"html/template"
"io"
"log"
"math"
"net/http"
"os"
"path"
@ -18,7 +20,9 @@ import (
"strings"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
)
type RegistryOptions struct {
@ -34,9 +38,10 @@ type Model struct {
System string
Digest string
Options map[string]interface{}
Embeddings []vector.Embedding
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
@ -51,6 +56,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
First bool
System string
Prompt string
Embed string
// deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int
@ -60,6 +66,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
vars.System = m.System
vars.Prompt = request.Prompt
vars.Context = request.Context
vars.Embed = embedding
if request.System != "" {
vars.System = request.System
@ -157,6 +164,16 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
case "application/vnd.ollama.image.embed":
file, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open file: %s", filename)
}
defer file.Close()
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
return nil, err
}
case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
@ -195,6 +212,26 @@ func GetModel(name string) (*Model, error) {
return model, nil
}
func filenameWithPath(path, f string) (string, error) {
// if filePath starts with ~/, replace it with the user's home directory.
if strings.HasPrefix(f, "~/") {
parts := strings.Split(f, "/")
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to open file: %v", err)
}
f = filepath.Join(home, filepath.Join(parts[1:]...))
}
// if filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(f) {
f = filepath.Join(filepath.Dir(path), f)
}
return f, nil
}
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
mf, err := os.Open(path)
if err != nil {
@ -211,33 +248,20 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
var layers []*LayerReader
params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name {
case "model":
fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args
mf, err := GetManifest(ParseModelPath(c.Args))
if err != nil {
fp := c.Args
// If filePath starts with ~/, replace it with the user's home directory.
if strings.HasPrefix(fp, "~/") {
parts := strings.Split(fp, "/")
home, err := os.UserHomeDir()
modelFile, err := filenameWithPath(path, c.Args)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
return err
}
fp = filepath.Join(home, filepath.Join(parts[1:]...))
}
// If filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(fp) {
fp = filepath.Join(filepath.Dir(path), fp)
}
if _, err := os.Stat(fp); err != nil {
if _, err := os.Stat(modelFile); err != nil {
// the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
@ -248,15 +272,13 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err
}
} else {
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(fp)
file, err := os.Open(modelFile)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
@ -280,9 +302,14 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
layers = append(layers, newLayer)
}
}
case "embed":
embedFilePath, err := filenameWithPath(path, c.Args)
if err != nil {
return err
}
embed.files = append(embed.files, embedFilePath)
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layer, err := CreateLayer(strings.NewReader(c.Args))
@ -315,18 +342,35 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
paramData, err := paramsToReader(params)
formattedParams, err := formatParams(params)
if err != nil {
return fmt.Errorf("couldn't create params json: %v", err)
}
l, err := CreateLayer(paramData)
bts, err := json.Marshal(formattedParams)
if err != nil {
return err
}
l, err := CreateLayer(bytes.NewReader(bts))
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = api.DefaultOptions()
embed.opts.FromMap(formattedParams)
}
// generate the embedding layers
embeddingLayers, err := embeddingLayers(embed)
if err != nil {
return err
}
layers = append(layers, embeddingLayers...)
digests, err := getLayerDigests(layers)
if err != nil {
return err
@ -361,6 +405,138 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
return nil
}
type EmbeddingParams struct {
model string
opts api.Options
files []string // paths to files to embed
fn func(resp api.ProgressResponse)
}
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{}
if len(e.files) > 0 {
if _, err := os.Stat(e.model); err != nil {
if os.IsNotExist(err) {
// this is a model name rather than the file
model, err := GetModel(e.model)
if err != nil {
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
}
e.model = model.ModelPath
} else {
return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
}
}
e.opts.EmbeddingOnly = true
llm, err := llama.New(e.model, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
defer func() {
if llm != nil {
llm.Close()
}
}()
addedFiles := make(map[string]bool) // keep track of files that have already been added
for _, filePattern := range e.files {
matchingFiles, err := filepath.Glob(filePattern)
if err != nil {
return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
}
for _, filePath := range matchingFiles {
if addedFiles[filePath] {
continue
}
addedFiles[filePath] = true
// TODO: check file type
f, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("could not open embed file: %w", err)
}
scanner := bufio.NewScanner(f)
scanner.Split(bufio.ScanLines)
data := []string{}
for scanner.Scan() {
data = append(data, scanner.Text())
}
f.Close()
// the digest of the file is set here so that the client knows a new operation is in progress
fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
embeddings := []vector.Embedding{}
for i, d := range data {
if strings.TrimSpace(d) == "" {
continue
}
e.fn(api.ProgressResponse{
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
Digest: fileDigest,
Total: len(data) - 1,
Completed: i,
})
retry := 0
generate:
if retry > 3 {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue
}
embed, err := llm.Embedding(d)
if err != nil {
log.Printf("retrying embedding generation for '%s' line %d: %v", filePath, i+1, err)
retry++
goto generate
}
// Check for NaN and Inf in the embedding, which can't be stored
for _, value := range embed {
if math.IsNaN(value) || math.IsInf(value, 0) {
log.Printf("reloading model, embedding contains NaN or Inf")
// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
llm.Close()
llm, err = llama.New(e.model, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
retry++
goto generate
}
}
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
}
b, err := json.Marshal(embeddings)
if err != nil {
return nil, fmt.Errorf("failed to encode embeddings: %w", err)
}
r := bytes.NewReader(b)
digest, size := GetSHA256Digest(r)
// Reset the position of the reader after calculating the digest
if _, err := r.Seek(0, io.SeekStart); err != nil {
return nil, fmt.Errorf("could not reset embed reader: %w", err)
}
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.ollama.image.embed",
Digest: digest,
Size: size,
},
Reader: r,
}
layers = append(layers, layer)
}
}
}
return layers, nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
j := 0
for _, l := range layers {
@ -449,8 +625,8 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil
}
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@ -504,12 +680,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
}
}
bts, err := json.Marshal(out)
if err != nil {
return nil, err
}
return bytes.NewReader(bts), nil
return out, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) {
@ -1042,7 +1213,7 @@ func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func
for {
fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", digest),
Status: fmt.Sprintf("pulling %s...", digest[7:19]),
Digest: digest,
Total: int(total),
Completed: int(completed),

View file

@ -17,15 +17,18 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/vector"
)
var loaded struct {
mu sync.Mutex
llm *llama.LLM
Embeddings []vector.Embedding
expireAt time.Time
expireTimer *time.Timer
@ -72,6 +75,11 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = ""
}
if model.Embeddings != nil && len(model.Embeddings) > 0 {
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
loaded.Embeddings = model.Embeddings
}
llm, err := llama.New(model.ModelPath, opts)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -101,7 +109,6 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = model.Digest
loaded.options = opts
}
sessionDuration := 5 * time.Minute
loaded.expireAt = time.Now().Add(sessionDuration)
@ -127,7 +134,22 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
prompt, err := model.Prompt(req)
embedding := ""
if model.Embeddings != nil && len(model.Embeddings) > 0 {
promptEmbed, err := loaded.llm.Embedding(req.Prompt)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO: set embed_top from specified parameters in modelfile
embed_top := 3
topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
for _, e := range topK {
embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
}
}
prompt, err := model.Prompt(req, embedding)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return

69
vector/store.go Normal file
View file

@ -0,0 +1,69 @@
package vector
import (
"container/heap"
"sort"
"gonum.org/v1/gonum/mat"
)
type Embedding struct {
Vector []float64 // the embedding vector
Data string // the data represted by the embedding
}
type EmbeddingSimilarity struct {
Embedding Embedding // the embedding that was used to calculate the similarity
Similarity float64 // the similarity between the embedding and the query
}
type Heap []EmbeddingSimilarity
func (h Heap) Len() int { return len(h) }
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *Heap) Push(e any) {
*h = append(*h, e.(EmbeddingSimilarity))
}
func (h *Heap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
// This value will range from -1 to 1, where 1 means the vectors are identical.
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
dotProduct := mat.Dot(vec1, vec2)
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
if norms == 0 {
return 0
}
return dotProduct / norms
}
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
h := &Heap{}
heap.Init(h)
for _, emb := range embeddings {
similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
if h.Len() > k {
heap.Pop(h)
}
}
topK := make([]EmbeddingSimilarity, 0, h.Len())
for h.Len() > 0 {
topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
}
sort.Slice(topK, func(i, j int) bool {
return topK[i].Similarity > topK[j].Similarity
})
return topK
}