allow embedding from model binary

This commit is contained in:
Bruce MacDonald 2023-08-08 14:38:57 -04:00
parent 3ceac05108
commit 884d78ceb3

View file

@ -268,7 +268,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
mf, err = GetManifest(ParseModelPath(modelFile)) mf, err = GetManifest(ParseModelPath(c.Args))
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err) return fmt.Errorf("failed to open file after pull: %v", err)
} }
@ -354,6 +354,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
embed.opts.FromMap(formattedParams) embed.opts.FromMap(formattedParams)
} }
fmt.Println(embed.model)
// generate the embedding layers // generate the embedding layers
embeddingLayers, err := embeddingLayers(embed) embeddingLayers, err := embeddingLayers(embed)
if err != nil { if err != nil {
@ -406,13 +408,21 @@ type EmbeddingParams struct {
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{} layers := []*LayerReader{}
if len(e.files) > 0 { if len(e.files) > 0 {
if _, err := os.Stat(e.model); err != nil {
if os.IsNotExist(err) {
// this is a model name rather than the file
model, err := GetModel(e.model) model, err := GetModel(e.model)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err) return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
} }
e.model = model.ModelPath
} else {
return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
}
}
e.opts.EmbeddingOnly = true e.opts.EmbeddingOnly = true
llm, err := llama.New(model.ModelPath, e.opts) llm, err := llama.New(e.model, e.opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err) return nil, fmt.Errorf("load model to generate embeddings: %v", err)
} }
@ -475,7 +485,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
log.Printf("reloading model, embedding contains NaN or Inf") log.Printf("reloading model, embedding contains NaN or Inf")
// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it // reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
llm.Close() llm.Close()
llm, err = llama.New(model.ModelPath, e.opts) llm, err = llama.New(e.model, e.opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err) return nil, fmt.Errorf("load model to generate embeddings: %v", err)
} }