Merge pull request #898 from jmorganca/mxyng/build-context

create remote models
This commit is contained in:
Michael Yang 2023-11-15 16:41:12 -08:00 committed by GitHub
commit 77954bea0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 382 additions and 205 deletions

View file

@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
var reqBody io.Reader
var data []byte
var err error
if reqData != nil {
switch reqData := reqData.(type) {
case io.Reader:
// reqData is already an io.Reader
reqBody = reqData
case nil:
// noop
default:
data, err = json.Marshal(reqData)
if err != nil {
return err
}
reqBody = bytes.NewReader(data)
}
@ -287,3 +296,18 @@ func (c *Client) Heartbeat(ctx context.Context) error {
}
return nil
}
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
var statusError StatusError
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
return err
}
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
return err
}
}
return nil
}

View file

@ -99,9 +99,10 @@ type EmbeddingResponse struct {
}
type CreateRequest struct {
Name string `json:"name"`
Path string `json:"path"`
Stream *bool `json:"stream,omitempty"`
Name string `json:"name"`
Path string `json:"path"`
Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"`
}
type DeleteRequest struct {

View file

@ -1,9 +1,11 @@
package cmd
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/pem"
"errors"
"fmt"
@ -27,6 +29,7 @@ import (
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/progressbar"
"github.com/jmorganca/ollama/readline"
"github.com/jmorganca/ollama/server"
@ -45,17 +48,64 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
var spinner *Spinner
modelfile, err := os.ReadFile(filename)
if err != nil {
return err
}
spinner := NewSpinner("transferring context")
go spinner.Spin(100 * time.Millisecond)
commands, err := parser.Parse(bytes.NewReader(modelfile))
if err != nil {
return err
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
for _, c := range commands {
switch c.Name {
case "model", "adapter":
path := c.Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
}
bin, err := os.Open(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
continue
} else if err != nil {
return err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return err
}
bin.Seek(0, io.SeekStart)
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return err
}
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
}
}
var currentDigest string
var bar *progressbar.ProgressBar
request := api.CreateRequest{Name: args[0], Path: filename}
request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
fn := func(resp api.ProgressResponse) error {
if resp.Digest != currentDigest && resp.Digest != "" {
if spinner != nil {
spinner.Stop()
}
spinner.Stop()
currentDigest = resp.Digest
// pulling
bar = progressbar.DefaultBytes(
@ -67,9 +117,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
bar.Set64(resp.Completed)
} else {
currentDigest = ""
if spinner != nil {
spinner.Stop()
}
spinner.Stop()
spinner = NewSpinner(resp.Status)
go spinner.Spin(100 * time.Millisecond)
}
@ -81,11 +129,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
if spinner != nil {
spinner.Stop()
if spinner.description != "success" {
return errors.New("unexpected end to create model")
}
spinner.Stop()
if spinner.description != "success" {
return errors.New("unexpected end to create model")
}
return nil

View file

@ -292,12 +292,13 @@ curl -X POST http://localhost:11434/api/generate -d '{
POST /api/create
```
Create a model from a [`Modelfile`](./modelfile.md)
Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `modelfile` to the content of the Modelfile rather than just set `path`. This is a requirement for remote create. Remote model creation should also create any file blobs, fields such as `FROM` and `ADAPTER`, explicitly with the server using [Create a Blob](#create-a-blob) and the value to the path indicated in the response.
### Parameters
- `name`: name of the model to create
- `path`: path to the Modelfile
- `path`: path to the Modelfile (deprecated: please use modelfile instead)
- `modelfile`: contents of the Modelfile
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
@ -307,7 +308,8 @@ Create a model from a [`Modelfile`](./modelfile.md)
```shell
curl -X POST http://localhost:11434/api/create -d '{
"name": "mario",
"path": "~/Modelfile"
"path": "~/Modelfile",
"modelfile": "FROM llama2"
}'
```
@ -321,6 +323,54 @@ A stream of JSON objects. When finished, `status` is `success`.
}
```
### Check if a Blob Exists
```shell
HEAD /api/blobs/:digest
```
Check if a blob is known to the server.
#### Query Parameters
- `digest`: the SHA256 digest of the blob
#### Examples
##### Request
```shell
curl -I http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2
```
##### Response
Return 200 OK if the blob exists, 404 Not Found if it does not.
### Create a Blob
```shell
POST /api/blobs/:digest
```
Create a blob from a file. Returns the server file path.
#### Query Parameters
- `digest`: the expected SHA256 digest of the file
#### Examples
##### Request
```shell
curl -T model.bin -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2
```
##### Response
Return 201 Created if the blob was successfully created.
## List Local Models
```shell

View file

@ -248,200 +248,181 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil
}
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]bool)
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = true
}
deleteMap[manifest.Config.Digest] = true
}
}
mf, err := os.Open(path)
func realpath(p string) string {
abspath, err := filepath.Abs(p)
if err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
return fmt.Errorf("failed to open file: %w", err)
return p
}
defer mf.Close()
fn(api.ProgressResponse{Status: "parsing modelfile"})
commands, err := parser.Parse(mf)
home, err := os.UserHomeDir()
if err != nil {
return err
return abspath
}
if p == "~" {
return home
} else if strings.HasPrefix(p, "~/") {
return filepath.Join(home, p[2:])
}
return abspath
}
func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
config := ConfigV2{
Architecture: "amd64",
OS: "linux",
Architecture: "amd64",
}
deleteMap := make(map[string]struct{})
var layers []*LayerReader
params := make(map[string][]string)
var sourceParams map[string]any
fromParams := make(map[string]any)
for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args)
log.Printf("[%s] - %s", c.Name, c.Args)
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name {
case "model":
fn(api.ProgressResponse{Status: "looking for model"})
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
mp := ParseModelPath(c.Args)
mf, _, err := GetManifest(mp)
c.Args = blobPath
}
bin, err := os.Open(realpath(c.Args))
if err != nil {
modelFile, err := filenameWithPath(path, c.Args)
if err != nil {
// not a file on disk so must be a model reference
modelpath := ParseModelPath(c.Args)
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
manifest, _, err = GetManifest(modelpath)
if err != nil {
return err
}
case err != nil:
return err
}
if _, err := os.Stat(modelFile); err != nil {
// the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, _, err = GetManifest(mp)
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err
}
} else {
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
ggml, err := llm.DecodeGGML(file)
if err != nil {
return err
}
config.ModelFormat = ggml.Name()
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
// reset the file
file.Seek(0, io.SeekStart)
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l)
}
}
if mf != nil {
fn(api.ProgressResponse{Status: "reading model metadata"})
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return err
}
sourceBlob, err := os.Open(sourceBlobPath)
fromConfigFile, err := os.Open(fromConfigPath)
if err != nil {
return err
}
defer sourceBlob.Close()
defer fromConfigFile.Close()
var source ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
var fromConfig ConfigV2
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
return err
}
// copy the model metadata
config.ModelFamily = source.ModelFamily
config.ModelType = source.ModelType
config.ModelFormat = source.ModelFormat
config.FileType = source.FileType
config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = fromConfig.ModelFamily
config.ModelType = fromConfig.ModelType
config.FileType = fromConfig.FileType
for _, l := range mf.Layers {
if l.MediaType == "application/vnd.ollama.image.params" {
sourceParamsBlobPath, err := GetBlobsPath(l.Digest)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
sourceParamsBlob, err := os.Open(sourceParamsBlobPath)
fromParamsFile, err := os.Open(fromParamsPath)
if err != nil {
return err
}
defer sourceParamsBlob.Close()
defer fromParamsFile.Close()
if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil {
if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
return err
}
}
newLayer, err := GetLayerWithBufferFromLayer(l)
layer, err := GetLayerWithBufferFromLayer(layer)
if err != nil {
return err
}
newLayer.From = mp.GetShortTagname()
layers = append(layers, newLayer)
}
}
case "adapter":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
fp, err := filenameWithPath(path, c.Args)
layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
continue
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil {
return err
}
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
config.ModelFormat = ggml.Name()
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
file, err := os.Open(fp)
bin.Seek(0, io.SeekStart)
layer, err := CreateLayer(bin)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
return err
}
defer file.Close()
l, err := CreateLayer(file)
layer.MediaType = mediatype
layers = append(layers, layer)
case "adapter":
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(c.Args))
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
return err
}
l.MediaType = "application/vnd.ollama.image.adapter"
layers = append(layers, l)
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
defer bin.Close()
layer, err := CreateLayer(strings.NewReader(c.Args))
layer, err := CreateLayer(bin)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediaType
layer.MediaType = mediatype
layers = append(layers, layer)
}
case "template", "system", "prompt":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layers = removeLayerFromLayers(layers, mediaType)
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove duplicate layers
layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil {
@ -449,48 +430,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
}
if layer.Size > 0 {
layer.MediaType = mediaType
layer.MediaType = mediatype
layers = append(layers, layer)
}
default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args)
}
}
// Create a single layer for the parameters
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"})
fn(api.ProgressResponse{Status: "creating parameters layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
formattedParams, err := formatParams(params)
if err != nil {
return fmt.Errorf("couldn't create params json: %v", err)
return err
}
for k, v := range sourceParams {
for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v
}
}
if config.ModelType == "65B" {
if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B"
}
}
bts, err := json.Marshal(formattedParams)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
return err
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
l, err := CreateLayer(bytes.NewReader(bts))
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l)
layer.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, layer)
}
digests, err := getLayerDigests(layers)
@ -498,36 +478,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err
}
var manifestLayers []*Layer
for _, l := range layers {
manifestLayers = append(manifestLayers, &l.Layer)
delete(deleteMap, l.Layer.Digest)
}
// Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(config, digests)
configLayer, err := createConfigLayer(config, digests)
if err != nil {
return err
}
layers = append(layers, cfg)
delete(deleteMap, cfg.Layer.Digest)
layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil {
return err
}
// Create the manifest
var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers)
if err != nil {
if err := CreateManifest(name, configLayer, contentLayers); err != nil {
return err
}
if noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
return err
}
}
@ -739,7 +714,7 @@ func CopyModel(src, dest string) error {
return nil
}
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error {
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
fp, err := GetManifestPath()
if err != nil {
return err
@ -779,21 +754,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
}
// only delete the files which are still in the deleteMap
for k, v := range deleteMap {
if v {
fp, err := GetBlobsPath(k)
if err != nil {
log.Printf("couldn't get file path for '%s': %v", k, err)
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
log.Printf("couldn't get file path for '%s': %v", k, err)
continue
}
if !dryRun {
if err := os.Remove(fp); err != nil {
log.Printf("couldn't remove file '%s': %v", fp, err)
continue
}
if !dryRun {
if err := os.Remove(fp); err != nil {
log.Printf("couldn't remove file '%s': %v", fp, err)
continue
}
} else {
log.Printf("wanted to remove: %s", fp)
}
} else {
log.Printf("wanted to remove: %s", fp)
}
}
@ -801,7 +774,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
}
func PruneLayers() error {
deleteMap := make(map[string]bool)
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
if err != nil {
return err
@ -818,7 +791,9 @@ func PruneLayers() error {
if runtime.GOOS == "windows" {
name = strings.ReplaceAll(name, "-", ":")
}
deleteMap[name] = true
if strings.HasPrefix(name, "sha256:") {
deleteMap[name] = struct{}{}
}
}
log.Printf("total blobs: %d", len(deleteMap))
@ -873,11 +848,11 @@ func DeleteModel(name string) error {
return err
}
deleteMap := make(map[string]bool)
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true
deleteMap[layer.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = true
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap, false)
if err != nil {
@ -1013,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]bool)
deleteMap := make(map[string]struct{})
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp)
@ -1023,9 +998,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = true
deleteMap[l.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = true
deleteMap[manifest.Config.Digest] = struct{}{}
}
}

View file

@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
@ -26,6 +27,7 @@ import (
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
)
@ -409,8 +411,31 @@ func CreateModelHandler(c *gin.Context) {
return
}
if req.Name == "" || req.Path == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}
var modelfile io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
bin, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer bin.Close()
modelfile = bin
}
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@ -424,7 +449,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
if err := CreateModel(ctx, req.Name, commands, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -625,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
}
}
func HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
c.Status(http.StatusOK)
}
func CreateBlobHandler(c *gin.Context) {
hash := sha256.New()
temp, err := os.CreateTemp("", c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
targetPath, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Status(http.StatusCreated)
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
@ -684,6 +763,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {