client create modelfile
This commit is contained in:
parent
3ca56b5ada
commit
1552cee59f
4 changed files with 147 additions and 14 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||||
var reqBody io.Reader
|
var reqBody io.Reader
|
||||||
var data []byte
|
var data []byte
|
||||||
var err error
|
var err error
|
||||||
if reqData != nil {
|
|
||||||
|
switch reqData := reqData.(type) {
|
||||||
|
case io.Reader:
|
||||||
|
// reqData is already an io.Reader
|
||||||
|
reqBody = reqData
|
||||||
|
case nil:
|
||||||
|
// noop
|
||||||
|
default:
|
||||||
data, err = json.Marshal(reqData)
|
data, err = json.Marshal(reqData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqBody = bytes.NewReader(data)
|
reqBody = bytes.NewReader(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -287,3 +296,19 @@ func (c *Client) Heartbeat(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) {
|
||||||
|
var response CreateBlobResponse
|
||||||
|
if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil {
|
||||||
|
var statusError StatusError
|
||||||
|
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.Path, nil
|
||||||
|
}
|
||||||
|
|
|
@ -105,6 +105,10 @@ type CreateRequest struct {
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateBlobResponse struct {
|
||||||
|
Path string `json:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
type DeleteRequest struct {
|
type DeleteRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
63
cmd/cmd.go
63
cmd/cmd.go
|
@ -1,9 +1,11 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -27,6 +29,7 @@ import (
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/format"
|
"github.com/jmorganca/ollama/format"
|
||||||
|
"github.com/jmorganca/ollama/parser"
|
||||||
"github.com/jmorganca/ollama/progressbar"
|
"github.com/jmorganca/ollama/progressbar"
|
||||||
"github.com/jmorganca/ollama/readline"
|
"github.com/jmorganca/ollama/readline"
|
||||||
"github.com/jmorganca/ollama/server"
|
"github.com/jmorganca/ollama/server"
|
||||||
|
@ -45,17 +48,65 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var spinner *Spinner
|
modelfile, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
spinner := NewSpinner("transferring context")
|
||||||
|
go spinner.Spin(100 * time.Millisecond)
|
||||||
|
|
||||||
|
commands, err := parser.Parse(bytes.NewReader(modelfile))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range commands {
|
||||||
|
switch c.Name {
|
||||||
|
case "model", "adapter":
|
||||||
|
path := c.Args
|
||||||
|
if path == "~" {
|
||||||
|
path = home
|
||||||
|
} else if strings.HasPrefix(path, "~/") {
|
||||||
|
path = filepath.Join(home, path[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := os.Open(path)
|
||||||
|
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
|
||||||
|
// value might be a model reference and not a real file
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer bin.Close()
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
if _, err := io.Copy(hash, bin); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bin.Seek(0, io.SeekStart)
|
||||||
|
|
||||||
|
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||||
|
path, err = client.CreateBlob(cmd.Context(), digest, bin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var currentDigest string
|
var currentDigest string
|
||||||
var bar *progressbar.ProgressBar
|
var bar *progressbar.ProgressBar
|
||||||
|
|
||||||
request := api.CreateRequest{Name: args[0], Path: filename}
|
request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
|
||||||
fn := func(resp api.ProgressResponse) error {
|
fn := func(resp api.ProgressResponse) error {
|
||||||
if resp.Digest != currentDigest && resp.Digest != "" {
|
if resp.Digest != currentDigest && resp.Digest != "" {
|
||||||
if spinner != nil {
|
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
}
|
|
||||||
currentDigest = resp.Digest
|
currentDigest = resp.Digest
|
||||||
// pulling
|
// pulling
|
||||||
bar = progressbar.DefaultBytes(
|
bar = progressbar.DefaultBytes(
|
||||||
|
@ -67,9 +118,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
bar.Set64(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
} else {
|
} else {
|
||||||
currentDigest = ""
|
currentDigest = ""
|
||||||
if spinner != nil {
|
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
}
|
|
||||||
spinner = NewSpinner(resp.Status)
|
spinner = NewSpinner(resp.Status)
|
||||||
go spinner.Spin(100 * time.Millisecond)
|
go spinner.Spin(100 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
@ -81,12 +130,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if spinner != nil {
|
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
if spinner.description != "success" {
|
if spinner.description != "success" {
|
||||||
return errors.New("unexpected end to create model")
|
return errors.New("unexpected end to create model")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetBlobHandler(c *gin.Context) {
|
||||||
|
path, err := GetBlobsPath(c.Param("digest"))
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(path); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path})
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBlobHandler(c *gin.Context) {
|
||||||
|
hash := sha256.New()
|
||||||
|
temp, err := os.CreateTemp("", c.Param("digest"))
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer temp.Close()
|
||||||
|
defer os.Remove(temp.Name())
|
||||||
|
|
||||||
|
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := temp.Close(); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetPath, err := GetBlobsPath(c.Param("digest"))
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(temp.Name(), targetPath); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath})
|
||||||
|
}
|
||||||
|
|
||||||
var defaultAllowOrigins = []string{
|
var defaultAllowOrigins = []string{
|
||||||
"localhost",
|
"localhost",
|
||||||
"127.0.0.1",
|
"127.0.0.1",
|
||||||
|
@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||||
r.POST("/api/copy", CopyModelHandler)
|
r.POST("/api/copy", CopyModelHandler)
|
||||||
r.DELETE("/api/delete", DeleteModelHandler)
|
r.DELETE("/api/delete", DeleteModelHandler)
|
||||||
r.POST("/api/show", ShowModelHandler)
|
r.POST("/api/show", ShowModelHandler)
|
||||||
|
r.POST("/api/blobs/:digest", CreateBlobHandler)
|
||||||
|
|
||||||
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
||||||
r.Handle(method, "/", func(c *gin.Context) {
|
r.Handle(method, "/", func(c *gin.Context) {
|
||||||
|
@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle(method, "/api/tags", ListModelsHandler)
|
r.Handle(method, "/api/tags", ListModelsHandler)
|
||||||
|
r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
||||||
|
|
Loading…
Add table
Reference in a new issue