download models when creating from modelfile
This commit is contained in:
parent
12ab8f8f5f
commit
4c1caa3733
5 changed files with 73 additions and 41 deletions
|
@ -189,11 +189,11 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateProgressFunc func(CreateProgress) error
|
type CreateProgressFunc func(ProgressResponse) error
|
||||||
|
|
||||||
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
||||||
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
||||||
var resp CreateProgress
|
var resp ProgressResponse
|
||||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,10 +40,6 @@ type CreateRequest struct {
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateProgress struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DeleteRequest struct {
|
type DeleteRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
21
cmd/cmd.go
21
cmd/cmd.go
|
@ -36,15 +36,32 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
var spinner *Spinner
|
var spinner *Spinner
|
||||||
|
|
||||||
|
var currentDigest string
|
||||||
|
var bar *progressbar.ProgressBar
|
||||||
|
|
||||||
request := api.CreateRequest{Name: args[0], Path: filename}
|
request := api.CreateRequest{Name: args[0], Path: filename}
|
||||||
fn := func(resp api.CreateProgress) error {
|
fn := func(resp api.ProgressResponse) error {
|
||||||
|
if resp.Digest != currentDigest && resp.Digest != "" {
|
||||||
if spinner != nil {
|
if spinner != nil {
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
}
|
}
|
||||||
|
currentDigest = resp.Digest
|
||||||
|
bar = progressbar.DefaultBytes(
|
||||||
|
int64(resp.Total),
|
||||||
|
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
|
||||||
|
)
|
||||||
|
|
||||||
|
bar.Set(resp.Completed)
|
||||||
|
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
||||||
|
bar.Set(resp.Completed)
|
||||||
|
} else {
|
||||||
|
currentDigest = ""
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
spinner = NewSpinner(resp.Status)
|
spinner = NewSpinner(resp.Status)
|
||||||
go spinner.Spin(100 * time.Millisecond)
|
go spinner.Spin(100 * time.Millisecond)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -187,15 +187,15 @@ func GetModel(name string) (*Model, error) {
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateModel(name string, path string, fn func(status string)) error {
|
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
|
||||||
mf, err := os.Open(path)
|
mf, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fn(fmt.Sprintf("couldn't open modelfile '%s'", path))
|
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
|
||||||
return fmt.Errorf("failed to open file: %w", err)
|
return fmt.Errorf("failed to open file: %w", err)
|
||||||
}
|
}
|
||||||
defer mf.Close()
|
defer mf.Close()
|
||||||
|
|
||||||
fn("parsing modelfile")
|
fn(api.ProgressResponse{Status: "parsing modelfile"})
|
||||||
commands, err := parser.Parse(mf)
|
commands, err := parser.Parse(mf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -208,7 +208,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
||||||
switch c.Name {
|
switch c.Name {
|
||||||
case "model":
|
case "model":
|
||||||
fn("looking for model")
|
fn(api.ProgressResponse{Status: "looking for model"})
|
||||||
mf, err := GetManifest(ParseModelPath(c.Args))
|
mf, err := GetManifest(ParseModelPath(c.Args))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fp := c.Args
|
fp := c.Args
|
||||||
|
@ -229,7 +229,25 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
fp = filepath.Join(filepath.Dir(path), fp)
|
fp = filepath.Join(filepath.Dir(path), fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn("creating model layer")
|
if _, err := os.Stat(fp); err != nil {
|
||||||
|
// the model file does not exist, try pulling it
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
fn(api.ProgressResponse{Status: "pulling model file"})
|
||||||
|
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mf, err = GetManifest(ParseModelPath(c.Args))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open file after pull: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// create a model from this specified file
|
||||||
|
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||||
|
|
||||||
file, err := os.Open(fp)
|
file, err := os.Open(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open file: %v", err)
|
return fmt.Errorf("failed to open file: %v", err)
|
||||||
|
@ -242,7 +260,9 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
}
|
}
|
||||||
l.MediaType = "application/vnd.ollama.image.model"
|
l.MediaType = "application/vnd.ollama.image.model"
|
||||||
layers = append(layers, l)
|
layers = append(layers, l)
|
||||||
} else {
|
}
|
||||||
|
}
|
||||||
|
if mf != nil {
|
||||||
log.Printf("manifest = %#v", mf)
|
log.Printf("manifest = %#v", mf)
|
||||||
for _, l := range mf.Layers {
|
for _, l := range mf.Layers {
|
||||||
newLayer, err := GetLayerWithBufferFromLayer(l)
|
newLayer, err := GetLayerWithBufferFromLayer(l)
|
||||||
|
@ -253,7 +273,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "license", "template", "system", "prompt":
|
case "license", "template", "system", "prompt":
|
||||||
fn(fmt.Sprintf("creating %s layer", c.Name))
|
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
||||||
// remove the prompt layer if one exists
|
// remove the prompt layer if one exists
|
||||||
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||||
layers = removeLayerFromLayers(layers, mediaType)
|
layers = removeLayerFromLayers(layers, mediaType)
|
||||||
|
@ -272,7 +292,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
|
|
||||||
// Create a single layer for the parameters
|
// Create a single layer for the parameters
|
||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
fn("creating parameter layer")
|
fn(api.ProgressResponse{Status: "creating parameter layer"})
|
||||||
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
|
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
|
||||||
paramData, err := paramsToReader(params)
|
paramData, err := paramsToReader(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -297,7 +317,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a layer for the config object
|
// Create a layer for the config object
|
||||||
fn("creating config layer")
|
fn(api.ProgressResponse{Status: "creating config layer"})
|
||||||
cfg, err := createConfigLayer(digests)
|
cfg, err := createConfigLayer(digests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -310,13 +330,13 @@ func CreateModel(name string, path string, fn func(status string)) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the manifest
|
// Create the manifest
|
||||||
fn("writing manifest")
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
err = CreateManifest(name, cfg, manifestLayers)
|
err = CreateManifest(name, cfg, manifestLayers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fn("success")
|
fn(api.ProgressResponse{Status: "success"})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -331,7 +351,7 @@ func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerRead
|
||||||
return layers[:j]
|
return layers[:j]
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error {
|
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
|
||||||
// Write each of the layers to disk
|
// Write each of the layers to disk
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
fp, err := GetBlobsPath(layer.Digest)
|
fp, err := GetBlobsPath(layer.Digest)
|
||||||
|
@ -341,7 +361,8 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
|
||||||
|
|
||||||
_, err = os.Stat(fp)
|
_, err = os.Stat(fp)
|
||||||
if os.IsNotExist(err) || force {
|
if os.IsNotExist(err) || force {
|
||||||
fn(fmt.Sprintf("writing layer %s", layer.Digest))
|
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
|
||||||
|
|
||||||
out, err := os.Create(fp)
|
out, err := os.Create(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("couldn't create %s", fp)
|
log.Printf("couldn't create %s", fp)
|
||||||
|
@ -354,7 +375,7 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
fn(fmt.Sprintf("using already created layer %s", layer.Digest))
|
fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -147,10 +147,8 @@ func CreateModelHandler(c *gin.Context) {
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
fn := func(status string) {
|
fn := func(resp api.ProgressResponse) {
|
||||||
ch <- api.CreateProgress{
|
ch <- resp
|
||||||
Status: status,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := CreateModel(req.Name, req.Path, fn); err != nil {
|
if err := CreateModel(req.Name, req.Path, fn); err != nil {
|
||||||
|
|
Loading…
Reference in a new issue