use loaded llm for embeddings

This commit is contained in:
Bruce MacDonald 2023-08-15 10:35:39 -03:00
parent 18f2cb0472
commit 326de48930
2 changed files with 17 additions and 25 deletions

View file

@ -263,7 +263,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()} embed := EmbeddingParams{fn: fn}
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
@ -291,6 +291,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
} else { } else {
embed.model = modelFile
// create a model from this specified file // create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"}) fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile) file, err := os.Open(modelFile)
@ -422,8 +423,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, l) layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model // apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = api.DefaultOptions() embed.opts = formattedParams
embed.opts.FromMap(formattedParams)
} }
// generate the embedding layers // generate the embedding layers
@ -469,7 +469,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
type EmbeddingParams struct { type EmbeddingParams struct {
model string model string
opts api.Options opts map[string]interface{}
files []string // paths to files to embed files []string // paths to files to embed
fn func(resp api.ProgressResponse) fn func(resp api.ProgressResponse)
} }
@ -478,32 +478,22 @@ type EmbeddingParams struct {
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{} layers := []*LayerReader{}
if len(e.files) > 0 { if len(e.files) > 0 {
if _, err := os.Stat(e.model); err != nil { // check if the model is a file path or a model name
if os.IsNotExist(err) {
// this is a model name rather than the file
model, err := GetModel(e.model) model, err := GetModel(e.model)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err) if !strings.Contains(err.Error(), "couldn't open file") {
} return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
e.model = model.ModelPath
} else {
return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
} }
// the model may be a file path, create a model from this file
model = &Model{ModelPath: e.model}
} }
e.opts.EmbeddingOnly = true if err := load(model, e.opts, defaultSessionDuration); err != nil {
llmModel, err := llm.New(e.model, []string{}, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err) return nil, fmt.Errorf("load model to generate embeddings: %v", err)
} }
defer func() {
if llmModel != nil {
llmModel.Close()
}
}()
// this will be used to check if we already have embeddings for a file // this will be used to check if we already have embeddings for a file
modelInfo, err := os.Stat(e.model) modelInfo, err := os.Stat(model.ModelPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get model file info: %v", err) return nil, fmt.Errorf("failed to get model file info: %v", err)
} }
@ -561,7 +551,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]}) embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
continue continue
} }
embed, err := llmModel.Embedding(d) embed, err := loaded.llm.Embedding(d)
if err != nil { if err != nil {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue continue

View file

@ -38,6 +38,8 @@ var loaded struct {
options api.Options options api.Options
} }
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
opts := api.DefaultOptions() opts := api.DefaultOptions()
@ -134,7 +136,7 @@ func GenerateHandler(c *gin.Context) {
return return
} }
sessionDuration := 5 * time.Minute sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
if err := load(model, req.Options, sessionDuration); err != nil { if err := load(model, req.Options, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return