download models when creating from modelfile

This commit is contained in:
Bruce MacDonald 2023-07-25 14:25:13 -04:00
parent 12ab8f8f5f
commit 4c1caa3733
5 changed files with 73 additions and 41 deletions

View file

@ -189,11 +189,11 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
}) })
} }
type CreateProgressFunc func(CreateProgress) error type CreateProgressFunc func(ProgressResponse) error
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp CreateProgress var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil { if err := json.Unmarshal(bts, &resp); err != nil {
return err return err
} }

View file

@ -40,10 +40,6 @@ type CreateRequest struct {
Path string `json:"path"` Path string `json:"path"`
} }
type CreateProgress struct {
Status string `json:"status"`
}
type DeleteRequest struct { type DeleteRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }

View file

@ -36,15 +36,32 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
var spinner *Spinner var spinner *Spinner
var currentDigest string
var bar *progressbar.ProgressBar
request := api.CreateRequest{Name: args[0], Path: filename} request := api.CreateRequest{Name: args[0], Path: filename}
fn := func(resp api.CreateProgress) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != currentDigest && resp.Digest != "" {
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
} }
currentDigest = resp.Digest
bar = progressbar.DefaultBytes(
int64(resp.Total),
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
)
bar.Set(resp.Completed)
} else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed)
} else {
currentDigest = ""
if spinner != nil {
spinner.Stop()
}
spinner = NewSpinner(resp.Status) spinner = NewSpinner(resp.Status)
go spinner.Spin(100 * time.Millisecond) go spinner.Spin(100 * time.Millisecond)
}
return nil return nil
} }

View file

@ -187,15 +187,15 @@ func GetModel(name string) (*Model, error) {
return model, nil return model, nil
} }
func CreateModel(name string, path string, fn func(status string)) error { func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
mf, err := os.Open(path) mf, err := os.Open(path)
if err != nil { if err != nil {
fn(fmt.Sprintf("couldn't open modelfile '%s'", path)) fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
return fmt.Errorf("failed to open file: %w", err) return fmt.Errorf("failed to open file: %w", err)
} }
defer mf.Close() defer mf.Close()
fn("parsing modelfile") fn(api.ProgressResponse{Status: "parsing modelfile"})
commands, err := parser.Parse(mf) commands, err := parser.Parse(mf)
if err != nil { if err != nil {
return err return err
@ -208,7 +208,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
case "model": case "model":
fn("looking for model") fn(api.ProgressResponse{Status: "looking for model"})
mf, err := GetManifest(ParseModelPath(c.Args)) mf, err := GetManifest(ParseModelPath(c.Args))
if err != nil { if err != nil {
fp := c.Args fp := c.Args
@ -229,7 +229,25 @@ func CreateModel(name string, path string, fn func(status string)) error {
fp = filepath.Join(filepath.Dir(path), fp) fp = filepath.Join(filepath.Dir(path), fp)
} }
fn("creating model layer") if _, err := os.Stat(fp); err != nil {
// the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, err = GetManifest(ParseModelPath(c.Args))
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err
}
} else {
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(fp) file, err := os.Open(fp)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file: %v", err) return fmt.Errorf("failed to open file: %v", err)
@ -242,7 +260,9 @@ func CreateModel(name string, path string, fn func(status string)) error {
} }
l.MediaType = "application/vnd.ollama.image.model" l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l) layers = append(layers, l)
} else { }
}
if mf != nil {
log.Printf("manifest = %#v", mf) log.Printf("manifest = %#v", mf)
for _, l := range mf.Layers { for _, l := range mf.Layers {
newLayer, err := GetLayerWithBufferFromLayer(l) newLayer, err := GetLayerWithBufferFromLayer(l)
@ -253,7 +273,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
} }
} }
case "license", "template", "system", "prompt": case "license", "template", "system", "prompt":
fn(fmt.Sprintf("creating %s layer", c.Name)) fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists // remove the prompt layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layers = removeLayerFromLayers(layers, mediaType) layers = removeLayerFromLayers(layers, mediaType)
@ -272,7 +292,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
// Create a single layer for the parameters // Create a single layer for the parameters
if len(params) > 0 { if len(params) > 0 {
fn("creating parameter layer") fn(api.ProgressResponse{Status: "creating parameter layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
paramData, err := paramsToReader(params) paramData, err := paramsToReader(params)
if err != nil { if err != nil {
@ -297,7 +317,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
} }
// Create a layer for the config object // Create a layer for the config object
fn("creating config layer") fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(digests) cfg, err := createConfigLayer(digests)
if err != nil { if err != nil {
return err return err
@ -310,13 +330,13 @@ func CreateModel(name string, path string, fn func(status string)) error {
} }
// Create the manifest // Create the manifest
fn("writing manifest") fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers) err = CreateManifest(name, cfg, manifestLayers)
if err != nil { if err != nil {
return err return err
} }
fn("success") fn(api.ProgressResponse{Status: "success"})
return nil return nil
} }
@ -331,7 +351,7 @@ func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerRead
return layers[:j] return layers[:j]
} }
func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error { func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
// Write each of the layers to disk // Write each of the layers to disk
for _, layer := range layers { for _, layer := range layers {
fp, err := GetBlobsPath(layer.Digest) fp, err := GetBlobsPath(layer.Digest)
@ -341,7 +361,8 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
_, err = os.Stat(fp) _, err = os.Stat(fp)
if os.IsNotExist(err) || force { if os.IsNotExist(err) || force {
fn(fmt.Sprintf("writing layer %s", layer.Digest)) fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
out, err := os.Create(fp) out, err := os.Create(fp)
if err != nil { if err != nil {
log.Printf("couldn't create %s", fp) log.Printf("couldn't create %s", fp)
@ -354,7 +375,7 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
} }
} else { } else {
fn(fmt.Sprintf("using already created layer %s", layer.Digest)) fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
} }
} }

View file

@ -147,10 +147,8 @@ func CreateModelHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
fn := func(status string) { fn := func(resp api.ProgressResponse) {
ch <- api.CreateProgress{ ch <- resp
Status: status,
}
} }
if err := CreateModel(req.Name, req.Path, fn); err != nil { if err := CreateModel(req.Name, req.Path, fn); err != nil {