Merge pull request #898 from jmorganca/mxyng/build-context

create remote models
This commit is contained in:
Michael Yang 2023-11-15 16:41:12 -08:00 committed by GitHub
commit 77954bea0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 382 additions and 205 deletions

View file

@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
var reqBody io.Reader var reqBody io.Reader
var data []byte var data []byte
var err error var err error
if reqData != nil {
switch reqData := reqData.(type) {
case io.Reader:
// reqData is already an io.Reader
reqBody = reqData
case nil:
// noop
default:
data, err = json.Marshal(reqData) data, err = json.Marshal(reqData)
if err != nil { if err != nil {
return err return err
} }
reqBody = bytes.NewReader(data) reqBody = bytes.NewReader(data)
} }
@ -287,3 +296,18 @@ func (c *Client) Heartbeat(ctx context.Context) error {
} }
return nil return nil
} }
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
var statusError StatusError
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
return err
}
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
return err
}
}
return nil
}

View file

@ -101,6 +101,7 @@ type EmbeddingResponse struct {
type CreateRequest struct { type CreateRequest struct {
Name string `json:"name"` Name string `json:"name"`
Path string `json:"path"` Path string `json:"path"`
Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
} }

View file

@ -1,9 +1,11 @@
package cmd package cmd
import ( import (
"bytes"
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@ -27,6 +29,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/progressbar" "github.com/jmorganca/ollama/progressbar"
"github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/readline"
"github.com/jmorganca/ollama/server" "github.com/jmorganca/ollama/server"
@ -45,17 +48,64 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
var spinner *Spinner modelfile, err := os.ReadFile(filename)
if err != nil {
return err
}
spinner := NewSpinner("transferring context")
go spinner.Spin(100 * time.Millisecond)
commands, err := parser.Parse(bytes.NewReader(modelfile))
if err != nil {
return err
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
for _, c := range commands {
switch c.Name {
case "model", "adapter":
path := c.Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
}
bin, err := os.Open(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
continue
} else if err != nil {
return err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return err
}
bin.Seek(0, io.SeekStart)
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return err
}
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
}
}
var currentDigest string var currentDigest string
var bar *progressbar.ProgressBar var bar *progressbar.ProgressBar
request := api.CreateRequest{Name: args[0], Path: filename} request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != currentDigest && resp.Digest != "" { if resp.Digest != currentDigest && resp.Digest != "" {
if spinner != nil {
spinner.Stop() spinner.Stop()
}
currentDigest = resp.Digest currentDigest = resp.Digest
// pulling // pulling
bar = progressbar.DefaultBytes( bar = progressbar.DefaultBytes(
@ -67,9 +117,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
bar.Set64(resp.Completed) bar.Set64(resp.Completed)
} else { } else {
currentDigest = "" currentDigest = ""
if spinner != nil {
spinner.Stop() spinner.Stop()
}
spinner = NewSpinner(resp.Status) spinner = NewSpinner(resp.Status)
go spinner.Spin(100 * time.Millisecond) go spinner.Spin(100 * time.Millisecond)
} }
@ -81,12 +129,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
if spinner != nil {
spinner.Stop() spinner.Stop()
if spinner.description != "success" { if spinner.description != "success" {
return errors.New("unexpected end to create model") return errors.New("unexpected end to create model")
} }
}
return nil return nil
} }

View file

@ -292,12 +292,13 @@ curl -X POST http://localhost:11434/api/generate -d '{
POST /api/create POST /api/create
``` ```
Create a model from a [`Modelfile`](./modelfile.md) Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `modelfile` to the content of the Modelfile rather than just set `path`. This is a requirement for remote create. Remote model creation should also create any file blobs, fields such as `FROM` and `ADAPTER`, explicitly with the server using [Create a Blob](#create-a-blob) and the value to the path indicated in the response.
### Parameters ### Parameters
- `name`: name of the model to create - `name`: name of the model to create
- `path`: path to the Modelfile - `path`: path to the Modelfile (deprecated: please use modelfile instead)
- `modelfile`: contents of the Modelfile
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples ### Examples
@ -307,7 +308,8 @@ Create a model from a [`Modelfile`](./modelfile.md)
```shell ```shell
curl -X POST http://localhost:11434/api/create -d '{ curl -X POST http://localhost:11434/api/create -d '{
"name": "mario", "name": "mario",
"path": "~/Modelfile" "path": "~/Modelfile",
"modelfile": "FROM llama2"
}' }'
``` ```
@ -321,6 +323,54 @@ A stream of JSON objects. When finished, `status` is `success`.
} }
``` ```
### Check if a Blob Exists
```shell
HEAD /api/blobs/:digest
```
Check if a blob is known to the server.
#### Query Parameters
- `digest`: the SHA256 digest of the blob
#### Examples
##### Request
```shell
curl -I http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2
```
##### Response
Return 200 OK if the blob exists, 404 Not Found if it does not.
### Create a Blob
```shell
POST /api/blobs/:digest
```
Create a blob from a file. Returns the server file path.
#### Query Parameters
- `digest`: the expected SHA256 digest of the file
#### Examples
##### Request
```shell
curl -T model.bin -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2
```
##### Response
Return 201 Created if the blob was successfully created.
## List Local Models ## List Local Models
```shell ```shell

View file

@ -248,88 +248,131 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil return f, nil
} }
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { func realpath(p string) string {
mp := ParseModelPath(name) abspath, err := filepath.Abs(p)
var manifest *ManifestV2
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]bool)
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = true
}
deleteMap[manifest.Config.Digest] = true
}
}
mf, err := os.Open(path)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) return p
return fmt.Errorf("failed to open file: %w", err)
} }
defer mf.Close()
fn(api.ProgressResponse{Status: "parsing modelfile"}) home, err := os.UserHomeDir()
commands, err := parser.Parse(mf)
if err != nil { if err != nil {
return err return abspath
} }
if p == "~" {
return home
} else if strings.HasPrefix(p, "~/") {
return filepath.Join(home, p[2:])
}
return abspath
}
func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
config := ConfigV2{ config := ConfigV2{
Architecture: "amd64",
OS: "linux", OS: "linux",
Architecture: "amd64",
} }
deleteMap := make(map[string]struct{})
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
var sourceParams map[string]any fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s", c.Name, c.Args)
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
mp := ParseModelPath(c.Args)
mf, _, err := GetManifest(mp)
if err != nil {
modelFile, err := filenameWithPath(path, c.Args)
if err != nil { if err != nil {
return err return err
} }
if _, err := os.Stat(modelFile); err != nil {
// the model file does not exist, try pulling it c.Args = blobPath
if errors.Is(err, os.ErrNotExist) { }
fn(api.ProgressResponse{Status: "pulling model file"})
bin, err := os.Open(realpath(c.Args))
if err != nil {
// not a file on disk so must be a model reference
modelpath := ParseModelPath(c.Args)
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
mf, _, err = GetManifest(mp)
manifest, _, err = GetManifest(modelpath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err return err
} }
} else { case err != nil:
// create a model from this specified file return err
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
} }
defer file.Close()
ggml, err := llm.DecodeGGML(file) fn(api.ProgressResponse{Status: "reading model metadata"})
fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return err
}
fromConfigFile, err := os.Open(fromConfigPath)
if err != nil {
return err
}
defer fromConfigFile.Close()
var fromConfig ConfigV2
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
return err
}
config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = fromConfig.ModelFamily
config.ModelType = fromConfig.ModelType
config.FileType = fromConfig.FileType
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
fromParamsFile, err := os.Open(fromParamsPath)
if err != nil {
return err
}
defer fromParamsFile.Close()
if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
return err
}
}
layer, err := GetLayerWithBufferFromLayer(layer)
if err != nil {
return err
}
layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
continue
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil { if err != nil {
return err return err
} }
@ -339,109 +382,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
config.ModelType = ggml.ModelType() config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType() config.FileType = ggml.FileType()
// reset the file bin.Seek(0, io.SeekStart)
file.Seek(0, io.SeekStart) layer, err := CreateLayer(bin)
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l)
}
}
if mf != nil {
fn(api.ProgressResponse{Status: "reading model metadata"})
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
if err != nil { if err != nil {
return err return err
} }
sourceBlob, err := os.Open(sourceBlobPath) layer.MediaType = mediatype
if err != nil { layers = append(layers, layer)
return err
}
defer sourceBlob.Close()
var source ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
return err
}
// copy the model metadata
config.ModelFamily = source.ModelFamily
config.ModelType = source.ModelType
config.ModelFormat = source.ModelFormat
config.FileType = source.FileType
for _, l := range mf.Layers {
if l.MediaType == "application/vnd.ollama.image.params" {
sourceParamsBlobPath, err := GetBlobsPath(l.Digest)
if err != nil {
return err
}
sourceParamsBlob, err := os.Open(sourceParamsBlobPath)
if err != nil {
return err
}
defer sourceParamsBlob.Close()
if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil {
return err
}
}
newLayer, err := GetLayerWithBufferFromLayer(l)
if err != nil {
return err
}
newLayer.From = mp.GetShortTagname()
layers = append(layers, newLayer)
}
}
case "adapter": case "adapter":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(c.Args))
fp, err := filenameWithPath(path, c.Args)
if err != nil { if err != nil {
return err return err
} }
defer bin.Close()
// create a model from this specified file layer, err := CreateLayer(bin)
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(fp)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.adapter"
layers = append(layers, l)
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
return err return err
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
case "template", "system", "prompt": case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: "creating license layer"})
// remove the layer if one exists layer, err := CreateLayer(strings.NewReader(c.Args))
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) if err != nil {
layers = removeLayerFromLayers(layers, mediaType) return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove duplicate layers
layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
@ -449,48 +430,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
default: default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
// Create a single layer for the parameters
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
formattedParams, err := formatParams(params) formattedParams, err := formatParams(params)
if err != nil { if err != nil {
return fmt.Errorf("couldn't create params json: %v", err) return err
} }
for k, v := range sourceParams { for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok { if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v formattedParams[k] = v
} }
} }
if config.ModelType == "65B" { if config.ModelType == "65B" {
if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B" config.ModelType = "70B"
} }
} }
bts, err := json.Marshal(formattedParams) var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
return err
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
if err != nil { if err != nil {
return err return err
} }
l, err := CreateLayer(bytes.NewReader(bts)) layer.MediaType = "application/vnd.ollama.image.params"
if err != nil { layers = append(layers, layer)
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l)
} }
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
@ -498,36 +478,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
var manifestLayers []*Layer configLayer, err := createConfigLayer(config, digests)
for _, l := range layers {
manifestLayers = append(manifestLayers, &l.Layer)
delete(deleteMap, l.Layer.Digest)
}
// Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(config, digests)
if err != nil { if err != nil {
return err return err
} }
layers = append(layers, cfg)
delete(deleteMap, cfg.Layer.Digest) layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil { if err := SaveLayers(layers, fn, false); err != nil {
return err return err
} }
// Create the manifest var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers) if err := CreateManifest(name, configLayer, contentLayers); err != nil {
if err != nil {
return err return err
} }
if noprune == "" { if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"}) if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
return err return err
} }
} }
@ -739,7 +714,7 @@ func CopyModel(src, dest string) error {
return nil return nil
} }
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
fp, err := GetManifestPath() fp, err := GetManifestPath()
if err != nil { if err != nil {
return err return err
@ -779,8 +754,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
} }
// only delete the files which are still in the deleteMap // only delete the files which are still in the deleteMap
for k, v := range deleteMap { for k := range deleteMap {
if v {
fp, err := GetBlobsPath(k) fp, err := GetBlobsPath(k)
if err != nil { if err != nil {
log.Printf("couldn't get file path for '%s': %v", k, err) log.Printf("couldn't get file path for '%s': %v", k, err)
@ -795,13 +769,12 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
log.Printf("wanted to remove: %s", fp) log.Printf("wanted to remove: %s", fp)
} }
} }
}
return nil return nil
} }
func PruneLayers() error { func PruneLayers() error {
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("") p, err := GetBlobsPath("")
if err != nil { if err != nil {
return err return err
@ -818,7 +791,9 @@ func PruneLayers() error {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
name = strings.ReplaceAll(name, "-", ":") name = strings.ReplaceAll(name, "-", ":")
} }
deleteMap[name] = true if strings.HasPrefix(name, "sha256:") {
deleteMap[name] = struct{}{}
}
} }
log.Printf("total blobs: %d", len(deleteMap)) log.Printf("total blobs: %d", len(deleteMap))
@ -873,11 +848,11 @@ func DeleteModel(name string) error {
return err return err
} }
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true deleteMap[layer.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap, false) err = deleteUnusedLayers(&mp, deleteMap, false)
if err != nil { if err != nil {
@ -1013,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var noprune string var noprune string
// build deleteMap to prune unused layers // build deleteMap to prune unused layers
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp) manifest, _, err = GetManifest(mp)
@ -1023,9 +998,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if manifest != nil { if manifest != nil {
for _, l := range manifest.Layers { for _, l := range manifest.Layers {
deleteMap[l.Digest] = true deleteMap[l.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
} }
} }

View file

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -26,6 +27,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
@ -409,8 +411,31 @@ func CreateModelHandler(c *gin.Context) {
return return
} }
if req.Name == "" || req.Path == "" { if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}
var modelfile io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
bin, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer bin.Close()
modelfile = bin
}
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -424,7 +449,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { if err := CreateModel(ctx, req.Name, commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -625,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
} }
} }
func HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
c.Status(http.StatusOK)
}
func CreateBlobHandler(c *gin.Context) {
hash := sha256.New()
temp, err := os.CreateTemp("", c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
targetPath, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Status(http.StatusCreated)
}
var defaultAllowOrigins = []string{ var defaultAllowOrigins = []string{
"localhost", "localhost",
"127.0.0.1", "127.0.0.1",
@ -684,6 +763,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/copy", CopyModelHandler) r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler) r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler) r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} { for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) { r.Handle(method, "/", func(c *gin.Context) {