diff --git a/server/images.go b/server/images.go index b807e11b..fe41c9be 100644 --- a/server/images.go +++ b/server/images.go @@ -268,7 +268,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil { return err } - mf, err = GetManifest(ParseModelPath(modelFile)) + mf, err = GetManifest(ParseModelPath(c.Args)) if err != nil { return fmt.Errorf("failed to open file after pull: %v", err) } @@ -354,6 +354,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e embed.opts.FromMap(formattedParams) } + fmt.Println(embed.model) + // generate the embedding layers embeddingLayers, err := embeddingLayers(embed) if err != nil { @@ -406,13 +408,21 @@ type EmbeddingParams struct { func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { layers := []*LayerReader{} if len(e.files) > 0 { - model, err := GetModel(e.model) - if err != nil { - return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err) + if _, err := os.Stat(e.model); err != nil { + if os.IsNotExist(err) { + // this is a model name rather than the file + model, err := GetModel(e.model) + if err != nil { + return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err) + } + e.model = model.ModelPath + } else { + return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err) + } } e.opts.EmbeddingOnly = true - llm, err := llama.New(model.ModelPath, e.opts) + llm, err := llama.New(e.model, e.opts) if err != nil { return nil, fmt.Errorf("load model to generate embeddings: %v", err) } @@ -475,7 +485,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { log.Printf("reloading model, embedding contains NaN or Inf") // reload the model to get a new embedding, the seed can effect these outputs and reloading changes it llm.Close() - llm, err = llama.New(model.ModelPath, e.opts) + llm, err = llama.New(e.model, e.opts) if err != nil { return nil, fmt.Errorf("load model to generate embeddings: %v", err) }