diff --git a/go.mod b/go.mod index ad78e5c0..7728b6fd 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.10.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.11.0 // indirect golang.org/x/term v0.10.0 diff --git a/go.sum b/go.sum index eb48b178..b563494b 100644 --- a/go.sum +++ b/go.sum @@ -121,6 +121,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= diff --git a/server/images.go b/server/images.go index a3b3eb4c..cc284510 100644 --- a/server/images.go +++ b/server/images.go @@ -22,6 +22,8 @@ import ( "strings" "text/template" + "golang.org/x/exp/slices" + "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/parser" @@ -274,6 +276,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api var layers []*LayerReader params := make(map[string][]string) + var sourceParams map[string]any embed := EmbeddingParams{fn: fn} for _, c := range commands { log.Printf("[%s] - %s\n", c.Name, c.Args) @@ -357,6 +360,23 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api config.FileType = source.FileType for _, l := range mf.Layers { + if l.MediaType == "application/vnd.ollama.image.params" { + sourceParamsBlobPath, err := GetBlobsPath(l.Digest) + if err != nil { + return err + } + + sourceParamsBlob, err := os.Open(sourceParamsBlobPath) + if err != nil { + return err + } + defer sourceParamsBlob.Close() + + if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil { + return err + } + } + newLayer, err := GetLayerWithBufferFromLayer(l) if err != nil { return err @@ -427,12 +447,19 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api // Create a single layer for the parameters if len(params) > 0 { fn(api.ProgressResponse{Status: "creating parameter layer"}) + layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") formattedParams, err := formatParams(params) if err != nil { return fmt.Errorf("couldn't create params json: %v", err) } + for k, v := range sourceParams { + if _, ok := formattedParams[k]; !ok { + formattedParams[k] = v + } + } + bts, err := json.Marshal(formattedParams) if err != nil { return err @@ -630,14 +657,9 @@ func existingFileEmbeddings(digest string) (map[string][]float64, error) { } func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader { - j := 0 - for _, l := range layers { - if l.MediaType != mediaType { - layers[j] = l - j++ - } - } - return layers[:j] + return slices.DeleteFunc(layers, func(layer *LayerReader) bool { + return layer.MediaType == mediaType + }) } func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {