add rm command for models (#151)
This commit is contained in:
parent
8454f298ac
commit
e7a393de54
5 changed files with 166 additions and 25 deletions
|
@ -210,3 +210,16 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
||||||
}
|
}
|
||||||
return &lr, nil
|
return &lr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DeleteProgressFunc func(ProgressResponse) error
|
||||||
|
|
||||||
|
func (c *Client) Delete(ctx context.Context, req *DeleteRequest, fn DeleteProgressFunc) error {
|
||||||
|
return c.stream(ctx, http.MethodDelete, "/api/delete", req, func(bts []byte) error {
|
||||||
|
var resp ProgressResponse
|
||||||
|
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn(resp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
12
api/types.go
12
api/types.go
|
@ -37,6 +37,10 @@ type CreateProgress struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DeleteRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
type PullRequest struct {
|
type PullRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
@ -44,10 +48,10 @@ type PullRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProgressResponse struct {
|
type ProgressResponse struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Digest string `json:"digest,omitempty"`
|
Digest string `json:"digest,omitempty"`
|
||||||
Total int `json:"total,omitempty"`
|
Total int `json:"total,omitempty"`
|
||||||
Completed int `json:"completed,omitempty"`
|
Completed int `json:"completed,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PushRequest struct {
|
type PushRequest struct {
|
||||||
|
|
45
cmd/cmd.go
45
cmd/cmd.go
|
@ -25,7 +25,7 @@ import (
|
||||||
"github.com/jmorganca/ollama/server"
|
"github.com/jmorganca/ollama/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func create(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
filename, _ := cmd.Flags().GetString("file")
|
filename, _ := cmd.Flags().GetString("file")
|
||||||
filename, err := filepath.Abs(filename)
|
filename, err := filepath.Abs(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -59,7 +59,7 @@ func create(cmd *cobra.Command, args []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunRun(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
mp := server.ParseModelPath(args[0])
|
mp := server.ParseModelPath(args[0])
|
||||||
fp, err := mp.GetManifestPath(false)
|
fp, err := mp.GetManifestPath(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -86,7 +86,7 @@ func RunRun(cmd *cobra.Command, args []string) error {
|
||||||
return RunGenerate(cmd, args)
|
return RunGenerate(cmd, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func push(cmd *cobra.Command, args []string) error {
|
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
client := api.NewClient()
|
client := api.NewClient()
|
||||||
|
|
||||||
request := api.PushRequest{Name: args[0]}
|
request := api.PushRequest{Name: args[0]}
|
||||||
|
@ -101,7 +101,7 @@ func push(cmd *cobra.Command, args []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func list(cmd *cobra.Command, args []string) error {
|
func ListHandler(cmd *cobra.Command, args []string) error {
|
||||||
client := api.NewClient()
|
client := api.NewClient()
|
||||||
|
|
||||||
models, err := client.List(context.Background())
|
models, err := client.List(context.Background())
|
||||||
|
@ -131,7 +131,22 @@ func list(cmd *cobra.Command, args []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunPull(cmd *cobra.Command, args []string) error {
|
func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
client := api.NewClient()
|
||||||
|
|
||||||
|
request := api.DeleteRequest{Name: args[0]}
|
||||||
|
fn := func(resp api.ProgressResponse) error {
|
||||||
|
fmt.Println(resp.Status)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Delete(context.Background(), &request, fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PullHandler(cmd *cobra.Command, args []string) error {
|
||||||
return pull(args[0])
|
return pull(args[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,7 +305,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(line, "/list"):
|
case strings.HasPrefix(line, "/list"):
|
||||||
args := strings.Fields(line)
|
args := strings.Fields(line)
|
||||||
if err := list(cmd, args[1:]); err != nil {
|
if err := ListHandler(cmd, args[1:]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -387,7 +402,7 @@ func NewCLI() *cobra.Command {
|
||||||
Use: "create MODEL",
|
Use: "create MODEL",
|
||||||
Short: "Create a model from a Modelfile",
|
Short: "Create a model from a Modelfile",
|
||||||
Args: cobra.MinimumNArgs(1),
|
Args: cobra.MinimumNArgs(1),
|
||||||
RunE: create,
|
RunE: CreateHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
|
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
|
||||||
|
@ -396,7 +411,7 @@ func NewCLI() *cobra.Command {
|
||||||
Use: "run MODEL [PROMPT]",
|
Use: "run MODEL [PROMPT]",
|
||||||
Short: "Run a model",
|
Short: "Run a model",
|
||||||
Args: cobra.MinimumNArgs(1),
|
Args: cobra.MinimumNArgs(1),
|
||||||
RunE: RunRun,
|
RunE: RunHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||||
|
@ -412,21 +427,28 @@ func NewCLI() *cobra.Command {
|
||||||
Use: "pull MODEL",
|
Use: "pull MODEL",
|
||||||
Short: "Pull a model from a registry",
|
Short: "Pull a model from a registry",
|
||||||
Args: cobra.MinimumNArgs(1),
|
Args: cobra.MinimumNArgs(1),
|
||||||
RunE: RunPull,
|
RunE: PullHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
pushCmd := &cobra.Command{
|
pushCmd := &cobra.Command{
|
||||||
Use: "push MODEL",
|
Use: "push MODEL",
|
||||||
Short: "Push a model to a registry",
|
Short: "Push a model to a registry",
|
||||||
Args: cobra.MinimumNArgs(1),
|
Args: cobra.MinimumNArgs(1),
|
||||||
RunE: push,
|
RunE: PushHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
listCmd := &cobra.Command{
|
listCmd := &cobra.Command{
|
||||||
Use: "list",
|
Use: "list",
|
||||||
Aliases: []string{"ls"},
|
Aliases: []string{"ls"},
|
||||||
Short: "List models",
|
Short: "List models",
|
||||||
RunE: list,
|
RunE: ListHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
deleteCmd := &cobra.Command{
|
||||||
|
Use: "rm",
|
||||||
|
Short: "Remove a model",
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: DeleteHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
rootCmd.AddCommand(
|
rootCmd.AddCommand(
|
||||||
|
@ -436,6 +458,7 @@ func NewCLI() *cobra.Command {
|
||||||
pullCmd,
|
pullCmd,
|
||||||
pushCmd,
|
pushCmd,
|
||||||
listCmd,
|
listCmd,
|
||||||
|
deleteCmd,
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
|
|
@ -487,6 +487,83 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
||||||
return layer, nil
|
return layer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteModel(name string, fn func(api.ProgressResponse)) error {
|
||||||
|
mp := ParseModelPath(name)
|
||||||
|
|
||||||
|
manifest, err := GetManifest(mp)
|
||||||
|
if err != nil {
|
||||||
|
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
deleteMap := make(map[string]bool)
|
||||||
|
for _, layer := range manifest.Layers {
|
||||||
|
deleteMap[layer.Digest] = true
|
||||||
|
}
|
||||||
|
deleteMap[manifest.Config.Digest] = true
|
||||||
|
|
||||||
|
fp, err := GetManifestPath()
|
||||||
|
if err != nil {
|
||||||
|
fn(api.ProgressResponse{Status: "problem getting manifest path"})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
fn(api.ProgressResponse{Status: "problem walking manifest dir"})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
path := path[len(fp)+1:]
|
||||||
|
slashIndex := strings.LastIndex(path, "/")
|
||||||
|
if slashIndex == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
|
||||||
|
fmp := ParseModelPath(tag)
|
||||||
|
|
||||||
|
// skip the manifest we're trying to delete
|
||||||
|
if mp.GetFullTagname() == fmp.GetFullTagname() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// save (i.e. delete from the deleteMap) any files used in other manifests
|
||||||
|
manifest, err := GetManifest(fmp)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("skipping file: %s", fp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, layer := range manifest.Layers {
|
||||||
|
delete(deleteMap, layer.Digest)
|
||||||
|
}
|
||||||
|
delete(deleteMap, manifest.Config.Digest)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// only delete the files which are still in the deleteMap
|
||||||
|
for k, v := range deleteMap {
|
||||||
|
if v {
|
||||||
|
err := os.Remove(k)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("couldn't remove file '%s': %v", k, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fp, err = mp.GetManifestPath(false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = os.Remove(fp)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("couldn't remove manifest file '%s': %v", fp, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fn(api.ProgressResponse{Status: fmt.Sprintf("deleted '%s'", name)})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
|
func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llama"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generate(c *gin.Context) {
|
func GenerateHandler(c *gin.Context) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
|
@ -78,7 +78,7 @@ func generate(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func pull(c *gin.Context) {
|
func PullModelHandler(c *gin.Context) {
|
||||||
var req api.PullRequest
|
var req api.PullRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
@ -100,7 +100,7 @@ func pull(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func push(c *gin.Context) {
|
func PushModelHandler(c *gin.Context) {
|
||||||
var req api.PushRequest
|
var req api.PushRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
@ -122,7 +122,7 @@ func push(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(c *gin.Context) {
|
func CreateModelHandler(c *gin.Context) {
|
||||||
var req api.CreateRequest
|
var req api.CreateRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||||
|
@ -146,7 +146,30 @@ func create(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func list(c *gin.Context) {
|
func DeleteModelHandler(c *gin.Context) {
|
||||||
|
var req api.DeleteRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan any)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
fn := func(r api.ProgressResponse) {
|
||||||
|
ch <- r
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DeleteModel(req.Name, fn); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
streamResponse(c, ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListModelsHandler(c *gin.Context) {
|
||||||
var models []api.ListResponseModel
|
var models []api.ListResponseModel
|
||||||
fp, err := GetManifestPath()
|
fp, err := GetManifestPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -199,11 +222,12 @@ func Serve(ln net.Listener) error {
|
||||||
c.String(http.StatusOK, "Ollama is running")
|
c.String(http.StatusOK, "Ollama is running")
|
||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/api/pull", pull)
|
r.POST("/api/pull", PullModelHandler)
|
||||||
r.POST("/api/generate", generate)
|
r.POST("/api/generate", GenerateHandler)
|
||||||
r.POST("/api/create", create)
|
r.POST("/api/create", CreateModelHandler)
|
||||||
r.POST("/api/push", push)
|
r.POST("/api/push", PushModelHandler)
|
||||||
r.GET("/api/tags", list)
|
r.GET("/api/tags", ListModelsHandler)
|
||||||
|
r.DELETE("/api/delete", DeleteModelHandler)
|
||||||
|
|
||||||
log.Printf("Listening on %s", ln.Addr())
|
log.Printf("Listening on %s", ln.Addr())
|
||||||
s := &http.Server{
|
s := &http.Server{
|
||||||
|
|
Loading…
Reference in a new issue