allow embedding from model binary
This commit is contained in:
parent
3ceac05108
commit
884d78ceb3
1 changed files with 16 additions and 6 deletions
|
@ -268,7 +268,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
|||
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
mf, err = GetManifest(ParseModelPath(modelFile))
|
||||
mf, err = GetManifest(ParseModelPath(c.Args))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file after pull: %v", err)
|
||||
}
|
||||
|
@ -354,6 +354,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
|||
embed.opts.FromMap(formattedParams)
|
||||
}
|
||||
|
||||
fmt.Println(embed.model)
|
||||
|
||||
// generate the embedding layers
|
||||
embeddingLayers, err := embeddingLayers(embed)
|
||||
if err != nil {
|
||||
|
@ -406,13 +408,21 @@ type EmbeddingParams struct {
|
|||
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
||||
layers := []*LayerReader{}
|
||||
if len(e.files) > 0 {
|
||||
model, err := GetModel(e.model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
|
||||
if _, err := os.Stat(e.model); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// this is a model name rather than the file
|
||||
model, err := GetModel(e.model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
|
||||
}
|
||||
e.model = model.ModelPath
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
e.opts.EmbeddingOnly = true
|
||||
llm, err := llama.New(model.ModelPath, e.opts)
|
||||
llm, err := llama.New(e.model, e.opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||
}
|
||||
|
@ -475,7 +485,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
|||
log.Printf("reloading model, embedding contains NaN or Inf")
|
||||
// reload the model to get a new embedding, the seed can effect these outputs and reloading changes it
|
||||
llm.Close()
|
||||
llm, err = llama.New(model.ModelPath, e.opts)
|
||||
llm, err = llama.New(e.model, e.opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue