add rm command for models (#151)

This commit is contained in:
Patrick Devine 2023-07-20 16:09:23 -07:00 committed by GitHub
parent 8454f298ac
commit e7a393de54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 166 additions and 25 deletions

View file

@ -210,3 +210,16 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
} }
return &lr, nil return &lr, nil
} }
type DeleteProgressFunc func(ProgressResponse) error
func (c *Client) Delete(ctx context.Context, req *DeleteRequest, fn DeleteProgressFunc) error {
return c.stream(ctx, http.MethodDelete, "/api/delete", req, func(bts []byte) error {
var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}

View file

@ -37,6 +37,10 @@ type CreateProgress struct {
Status string `json:"status"` Status string `json:"status"`
} }
type DeleteRequest struct {
Name string `json:"name"`
}
type PullRequest struct { type PullRequest struct {
Name string `json:"name"` Name string `json:"name"`
Username string `json:"username"` Username string `json:"username"`

View file

@ -25,7 +25,7 @@ import (
"github.com/jmorganca/ollama/server" "github.com/jmorganca/ollama/server"
) )
func create(cmd *cobra.Command, args []string) error { func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file") filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename) filename, err := filepath.Abs(filename)
if err != nil { if err != nil {
@ -59,7 +59,7 @@ func create(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func RunRun(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
mp := server.ParseModelPath(args[0]) mp := server.ParseModelPath(args[0])
fp, err := mp.GetManifestPath(false) fp, err := mp.GetManifestPath(false)
if err != nil { if err != nil {
@ -86,7 +86,7 @@ func RunRun(cmd *cobra.Command, args []string) error {
return RunGenerate(cmd, args) return RunGenerate(cmd, args)
} }
func push(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client := api.NewClient() client := api.NewClient()
request := api.PushRequest{Name: args[0]} request := api.PushRequest{Name: args[0]}
@ -101,7 +101,7 @@ func push(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func list(cmd *cobra.Command, args []string) error { func ListHandler(cmd *cobra.Command, args []string) error {
client := api.NewClient() client := api.NewClient()
models, err := client.List(context.Background()) models, err := client.List(context.Background())
@ -131,7 +131,22 @@ func list(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func RunPull(cmd *cobra.Command, args []string) error { func DeleteHandler(cmd *cobra.Command, args []string) error {
client := api.NewClient()
request := api.DeleteRequest{Name: args[0]}
fn := func(resp api.ProgressResponse) error {
fmt.Println(resp.Status)
return nil
}
if err := client.Delete(context.Background(), &request, fn); err != nil {
return err
}
return nil
}
func PullHandler(cmd *cobra.Command, args []string) error {
return pull(args[0]) return pull(args[0])
} }
@ -290,7 +305,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
switch { switch {
case strings.HasPrefix(line, "/list"): case strings.HasPrefix(line, "/list"):
args := strings.Fields(line) args := strings.Fields(line)
if err := list(cmd, args[1:]); err != nil { if err := ListHandler(cmd, args[1:]); err != nil {
return err return err
} }
@ -387,7 +402,7 @@ func NewCLI() *cobra.Command {
Use: "create MODEL", Use: "create MODEL",
Short: "Create a model from a Modelfile", Short: "Create a model from a Modelfile",
Args: cobra.MinimumNArgs(1), Args: cobra.MinimumNArgs(1),
RunE: create, RunE: CreateHandler,
} }
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")") createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
@ -396,7 +411,7 @@ func NewCLI() *cobra.Command {
Use: "run MODEL [PROMPT]", Use: "run MODEL [PROMPT]",
Short: "Run a model", Short: "Run a model",
Args: cobra.MinimumNArgs(1), Args: cobra.MinimumNArgs(1),
RunE: RunRun, RunE: RunHandler,
} }
runCmd.Flags().Bool("verbose", false, "Show timings for response") runCmd.Flags().Bool("verbose", false, "Show timings for response")
@ -412,21 +427,28 @@ func NewCLI() *cobra.Command {
Use: "pull MODEL", Use: "pull MODEL",
Short: "Pull a model from a registry", Short: "Pull a model from a registry",
Args: cobra.MinimumNArgs(1), Args: cobra.MinimumNArgs(1),
RunE: RunPull, RunE: PullHandler,
} }
pushCmd := &cobra.Command{ pushCmd := &cobra.Command{
Use: "push MODEL", Use: "push MODEL",
Short: "Push a model to a registry", Short: "Push a model to a registry",
Args: cobra.MinimumNArgs(1), Args: cobra.MinimumNArgs(1),
RunE: push, RunE: PushHandler,
} }
listCmd := &cobra.Command{ listCmd := &cobra.Command{
Use: "list", Use: "list",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
Short: "List models", Short: "List models",
RunE: list, RunE: ListHandler,
}
deleteCmd := &cobra.Command{
Use: "rm",
Short: "Remove a model",
Args: cobra.MinimumNArgs(1),
RunE: DeleteHandler,
} }
rootCmd.AddCommand( rootCmd.AddCommand(
@ -436,6 +458,7 @@ func NewCLI() *cobra.Command {
pullCmd, pullCmd,
pushCmd, pushCmd,
listCmd, listCmd,
deleteCmd,
) )
return rootCmd return rootCmd

View file

@ -487,6 +487,83 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
return layer, nil return layer, nil
} }
func DeleteModel(name string, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
manifest, err := GetManifest(mp)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
deleteMap := make(map[string]bool)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true
}
deleteMap[manifest.Config.Digest] = true
fp, err := GetManifestPath()
if err != nil {
fn(api.ProgressResponse{Status: "problem getting manifest path"})
return err
}
err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
if err != nil {
fn(api.ProgressResponse{Status: "problem walking manifest dir"})
return err
}
if !info.IsDir() {
path := path[len(fp)+1:]
slashIndex := strings.LastIndex(path, "/")
if slashIndex == -1 {
return nil
}
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
fmp := ParseModelPath(tag)
// skip the manifest we're trying to delete
if mp.GetFullTagname() == fmp.GetFullTagname() {
return nil
}
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, err := GetManifest(fmp)
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
}
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
}
return nil
})
// only delete the files which are still in the deleteMap
for k, v := range deleteMap {
if v {
err := os.Remove(k)
if err != nil {
log.Printf("couldn't remove file '%s': %v", k, err)
continue
}
}
}
fp, err = mp.GetManifestPath(false)
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
log.Printf("couldn't remove manifest file '%s': %v", fp, err)
return err
}
fn(api.ProgressResponse{Status: fmt.Sprintf("deleted '%s'", name)})
return nil
}
func PushModel(name, username, password string, fn func(api.ProgressResponse)) error { func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)

View file

@ -18,7 +18,7 @@ import (
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
) )
func generate(c *gin.Context) { func GenerateHandler(c *gin.Context) {
start := time.Now() start := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
@ -78,7 +78,7 @@ func generate(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func pull(c *gin.Context) { func PullModelHandler(c *gin.Context) {
var req api.PullRequest var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -100,7 +100,7 @@ func pull(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func push(c *gin.Context) { func PushModelHandler(c *gin.Context) {
var req api.PushRequest var req api.PushRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -122,7 +122,7 @@ func push(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func create(c *gin.Context) { func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var req api.CreateRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
@ -146,7 +146,30 @@ func create(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func list(c *gin.Context) { func DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r api.ProgressResponse) {
ch <- r
}
if err := DeleteModel(req.Name, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}()
streamResponse(c, ch)
}
func ListModelsHandler(c *gin.Context) {
var models []api.ListResponseModel var models []api.ListResponseModel
fp, err := GetManifestPath() fp, err := GetManifestPath()
if err != nil { if err != nil {
@ -199,11 +222,12 @@ func Serve(ln net.Listener) error {
c.String(http.StatusOK, "Ollama is running") c.String(http.StatusOK, "Ollama is running")
}) })
r.POST("/api/pull", pull) r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", generate) r.POST("/api/generate", GenerateHandler)
r.POST("/api/create", create) r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", push) r.POST("/api/push", PushModelHandler)
r.GET("/api/tags", list) r.GET("/api/tags", ListModelsHandler)
r.DELETE("/api/delete", DeleteModelHandler)
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s", ln.Addr())
s := &http.Server{ s := &http.Server{