client create modelfile
This commit is contained in:
parent
3ca56b5ada
commit
1552cee59f
4 changed files with 147 additions and 14 deletions
|
@ -5,6 +5,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
|||
var reqBody io.Reader
|
||||
var data []byte
|
||||
var err error
|
||||
if reqData != nil {
|
||||
|
||||
switch reqData := reqData.(type) {
|
||||
case io.Reader:
|
||||
// reqData is already an io.Reader
|
||||
reqBody = reqData
|
||||
case nil:
|
||||
// noop
|
||||
default:
|
||||
data, err = json.Marshal(reqData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqBody = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
|
@ -287,3 +296,19 @@ func (c *Client) Heartbeat(ctx context.Context) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) {
|
||||
var response CreateBlobResponse
|
||||
if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil {
|
||||
var statusError StatusError
|
||||
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return response.Path, nil
|
||||
}
|
||||
|
|
|
@ -105,6 +105,10 @@ type CreateRequest struct {
|
|||
Stream *bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type CreateBlobResponse struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type DeleteRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
|
73
cmd/cmd.go
73
cmd/cmd.go
|
@ -1,9 +1,11 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -27,6 +29,7 @@ import (
|
|||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/progressbar"
|
||||
"github.com/jmorganca/ollama/readline"
|
||||
"github.com/jmorganca/ollama/server"
|
||||
|
@ -45,17 +48,65 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
var spinner *Spinner
|
||||
modelfile, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
spinner := NewSpinner("transferring context")
|
||||
go spinner.Spin(100 * time.Millisecond)
|
||||
|
||||
commands, err := parser.Parse(bytes.NewReader(modelfile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range commands {
|
||||
switch c.Name {
|
||||
case "model", "adapter":
|
||||
path := c.Args
|
||||
if path == "~" {
|
||||
path = home
|
||||
} else if strings.HasPrefix(path, "~/") {
|
||||
path = filepath.Join(home, path[2:])
|
||||
}
|
||||
|
||||
bin, err := os.Open(path)
|
||||
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
|
||||
// value might be a model reference and not a real file
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return err
|
||||
}
|
||||
bin.Seek(0, io.SeekStart)
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
path, err = client.CreateBlob(cmd.Context(), digest, bin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path))
|
||||
}
|
||||
}
|
||||
|
||||
var currentDigest string
|
||||
var bar *progressbar.ProgressBar
|
||||
|
||||
request := api.CreateRequest{Name: args[0], Path: filename}
|
||||
request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != currentDigest && resp.Digest != "" {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
spinner.Stop()
|
||||
currentDigest = resp.Digest
|
||||
// pulling
|
||||
bar = progressbar.DefaultBytes(
|
||||
|
@ -67,9 +118,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
bar.Set64(resp.Completed)
|
||||
} else {
|
||||
currentDigest = ""
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
spinner.Stop()
|
||||
spinner = NewSpinner(resp.Status)
|
||||
go spinner.Spin(100 * time.Millisecond)
|
||||
}
|
||||
|
@ -81,11 +130,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
if spinner.description != "success" {
|
||||
return errors.New("unexpected end to create model")
|
||||
}
|
||||
spinner.Stop()
|
||||
if spinner.description != "success" {
|
||||
return errors.New("unexpected end to create model")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
func GetBlobHandler(c *gin.Context) {
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path})
|
||||
}
|
||||
|
||||
func CreateBlobHandler(c *gin.Context) {
|
||||
hash := sha256.New()
|
||||
temp, err := os.CreateTemp("", c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
|
||||
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := temp.Close(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
targetPath, err := GetBlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Rename(temp.Name(), targetPath); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath})
|
||||
}
|
||||
|
||||
var defaultAllowOrigins = []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
|
@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
|||
r.POST("/api/copy", CopyModelHandler)
|
||||
r.DELETE("/api/delete", DeleteModelHandler)
|
||||
r.POST("/api/show", ShowModelHandler)
|
||||
r.POST("/api/blobs/:digest", CreateBlobHandler)
|
||||
|
||||
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
||||
r.Handle(method, "/", func(c *gin.Context) {
|
||||
|
@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
|||
})
|
||||
|
||||
r.Handle(method, "/api/tags", ListModelsHandler)
|
||||
r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler)
|
||||
}
|
||||
|
||||
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
||||
|
|
Loading…
Add table
Reference in a new issue