client create modelfile

This commit is contained in:
Michael Yang 2023-11-14 14:07:40 -08:00
parent 3ca56b5ada
commit 1552cee59f
4 changed files with 147 additions and 14 deletions

View file

@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
var reqBody io.Reader
var data []byte
var err error
if reqData != nil {
switch reqData := reqData.(type) {
case io.Reader:
// reqData is already an io.Reader
reqBody = reqData
case nil:
// noop
default:
data, err = json.Marshal(reqData)
if err != nil {
return err
}
reqBody = bytes.NewReader(data)
}
@ -287,3 +296,19 @@ func (c *Client) Heartbeat(ctx context.Context) error {
}
return nil
}
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) {
var response CreateBlobResponse
if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil {
var statusError StatusError
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
return "", err
}
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil {
return "", err
}
}
return response.Path, nil
}

View file

@ -105,6 +105,10 @@ type CreateRequest struct {
Stream *bool `json:"stream,omitempty"`
}
type CreateBlobResponse struct {
Path string `json:"path"`
}
type DeleteRequest struct {
Name string `json:"name"`
}

View file

@ -1,9 +1,11 @@
package cmd
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/pem"
"errors"
"fmt"
@ -27,6 +29,7 @@ import (
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/progressbar"
"github.com/jmorganca/ollama/readline"
"github.com/jmorganca/ollama/server"
@ -45,17 +48,65 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
var spinner *Spinner
modelfile, err := os.ReadFile(filename)
if err != nil {
return err
}
spinner := NewSpinner("transferring context")
go spinner.Spin(100 * time.Millisecond)
commands, err := parser.Parse(bytes.NewReader(modelfile))
if err != nil {
return err
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
for _, c := range commands {
switch c.Name {
case "model", "adapter":
path := c.Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
}
bin, err := os.Open(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
// value might be a model reference and not a real file
} else if err != nil {
return err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return err
}
bin.Seek(0, io.SeekStart)
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
path, err = client.CreateBlob(cmd.Context(), digest, bin)
if err != nil {
return err
}
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path))
}
}
var currentDigest string
var bar *progressbar.ProgressBar
request := api.CreateRequest{Name: args[0], Path: filename}
request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
fn := func(resp api.ProgressResponse) error {
if resp.Digest != currentDigest && resp.Digest != "" {
if spinner != nil {
spinner.Stop()
}
spinner.Stop()
currentDigest = resp.Digest
// pulling
bar = progressbar.DefaultBytes(
@ -67,9 +118,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
bar.Set64(resp.Completed)
} else {
currentDigest = ""
if spinner != nil {
spinner.Stop()
}
spinner.Stop()
spinner = NewSpinner(resp.Status)
go spinner.Spin(100 * time.Millisecond)
}
@ -81,11 +130,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
if spinner != nil {
spinner.Stop()
if spinner.description != "success" {
return errors.New("unexpected end to create model")
}
spinner.Stop()
if spinner.description != "success" {
return errors.New("unexpected end to create model")
}
return nil

View file

@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
}
}
func GetBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path})
}
func CreateBlobHandler(c *gin.Context) {
hash := sha256.New()
temp, err := os.CreateTemp("", c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
targetPath, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath})
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler)
}
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)