diff --git a/api/client.go b/api/client.go index f2349df5..9ba668d0 100644 --- a/api/client.go +++ b/api/client.go @@ -35,7 +35,7 @@ func checkError(resp *http.Response, body []byte) error { return apiError } -func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func (data []byte)) error { +func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error { var reqBody io.Reader var data []byte var err error @@ -140,3 +140,14 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu return &res, nil } + +func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) { + var res PullResponse + if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) { + callback(string(token)) + }); err != nil { + return nil, err + } + + return &res, nil +} diff --git a/api/types.go b/api/types.go index 5f104415..d1810726 100644 --- a/api/types.go +++ b/api/types.go @@ -18,6 +18,14 @@ func (e Error) Error() string { return e.Message } +type PullRequest struct { + Model string `json:"model"` +} + +type PullResponse struct { + Response string `json:"response"` +} + type GenerateRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` diff --git a/cmd/cmd.go b/cmd/cmd.go index c2b883e9..6ede5a39 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "fmt" "log" "net" "net/http" @@ -23,6 +24,21 @@ func cacheDir() string { return path.Join(home, ".ollama") } +func run(model string) error { + client, err := NewAPIClient() + if err != nil { + return err + } + pr := api.PullRequest{ + Model: model, + } + callback := func(progress string) { + fmt.Println(progress) + } + _, err = client.Pull(context.Background(), &pr, callback) + return err +} + func serve() error { sp := path.Join(cacheDir(), "ollama.sock") @@ -94,7 +110,7 @@ func NewCLI() *cobra.Command { Short: "Run a model", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return nil + return run(args[0]) }, } diff --git a/server/models.go b/server/models.go new file mode 100644 index 00000000..d5504390 --- /dev/null +++ b/server/models.go @@ -0,0 +1,140 @@ +package server + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "path" + "strconv" +) + +// const directoryURL = "https://ollama.ai/api/models" +const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json" + +type directoryCtxKey string + +var dirCtx directoryCtxKey = "directory" + +type Model struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` + Parameters string `json:"parameters"` + URL string `json:"url"` + ShortDescription string `json:"short_description"` + Description string `json:"description"` + PublishedBy string `json:"published_by"` + OriginalAuthor string `json:"original_author"` + OriginalURL string `json:"original_url"` + License string `json:"license"` +} + +func pull(model string, progressCh chan<- string) error { + remote, err := getRemote(model) + if err != nil { + return fmt.Errorf("failed to pull model: %w", err) + } + + return saveModel(remote, progressCh) +} + +func getRemote(model string) (*Model, error) { + // resolve the model download from our directory + resp, err := http.Get(directoryURL) + if err != nil { + return nil, fmt.Errorf("failed to get directory: %w", err) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read directory: %w", err) + } + var models []Model + err = json.Unmarshal(body, &models) + if err != nil { + return nil, fmt.Errorf("failed to parse directory: %w", err) + } + for _, m := range models { + if m.Name == model { + return &m, nil + } + } + return nil, fmt.Errorf("model not found in directory: %s", model) +} + +func saveModel(model *Model, progressCh chan<- string) error { + // this models cache directory is created by the server on startup + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get home directory: %w", err) + } + modelsCache := path.Join(home, ".ollama", "models") + + fileName := path.Join(modelsCache, model.Name+".bin") + + client := &http.Client{} + req, err := http.NewRequest("GET", model.URL, nil) + if err != nil { + panic(err) + } + // check for resume + fileInfo, err := os.Stat(fileName) + if err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("failed to check resume model file: %w", err) + } + // file doesn't exist, create it now + } else { + req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-") + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to download model: %w", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download model: %s", resp.Status) + } + + out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + panic(err) + } + defer out.Close() + + totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length")) + + buf := make([]byte, 1024) + totalBytes := 0 + + for { + n, err := resp.Body.Read(buf) + + if err != nil && err != io.EOF { + return err + } + + if n == 0 { + break + } + + if _, err := out.Write(buf[:n]); err != nil { + return err + } + + totalBytes += n + + // send progress updates + progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100) + } + + // send completion message + progressCh <- "Download complete!" + + return nil +} diff --git a/server/routes.go b/server/routes.go index cacd36a0..5eb2a048 100644 --- a/server/routes.go +++ b/server/routes.go @@ -14,12 +14,6 @@ import ( "github.com/jmorganca/ollama/api" ) -func pull(c *gin.Context) { - // TODO - - c.JSON(http.StatusOK, gin.H{"message": "ok"}) -} - func generate(c *gin.Context) { // TODO: these should be request parameters gpulayers := 1 @@ -65,7 +59,31 @@ func generate(c *gin.Context) { func Serve(ln net.Listener) error { r := gin.Default() - r.POST("api/pull", pull) + r.POST("api/pull", func(c *gin.Context) { + var req api.PullRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + + progressCh := make(chan string) + go func() { + defer close(progressCh) + if err := pull(req.Model, progressCh); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + }() + + c.Stream(func(w io.Writer) bool { + progress, ok := <-progressCh + if !ok { + return false + } + c.SSEvent("progress", progress) + return true + }) + }) r.POST("/api/generate", generate)