pull models
This commit is contained in:
parent
0833f5af3a
commit
a6494f8211
5 changed files with 202 additions and 9 deletions
|
@ -35,7 +35,7 @@ func checkError(resp *http.Response, body []byte) error {
|
|||
return apiError
|
||||
}
|
||||
|
||||
func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func (data []byte)) error {
|
||||
func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
|
||||
var reqBody io.Reader
|
||||
var data []byte
|
||||
var err error
|
||||
|
@ -140,3 +140,14 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
|
|||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) {
|
||||
var res PullResponse
|
||||
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) {
|
||||
callback(string(token))
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
|
|
@ -18,6 +18,14 @@ func (e Error) Error() string {
|
|||
return e.Message
|
||||
}
|
||||
|
||||
type PullRequest struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type PullResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
|
|
18
cmd/cmd.go
18
cmd/cmd.go
|
@ -2,6 +2,7 @@ package cmd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -23,6 +24,21 @@ func cacheDir() string {
|
|||
return path.Join(home, ".ollama")
|
||||
}
|
||||
|
||||
func run(model string) error {
|
||||
client, err := NewAPIClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pr := api.PullRequest{
|
||||
Model: model,
|
||||
}
|
||||
callback := func(progress string) {
|
||||
fmt.Println(progress)
|
||||
}
|
||||
_, err = client.Pull(context.Background(), &pr, callback)
|
||||
return err
|
||||
}
|
||||
|
||||
func serve() error {
|
||||
sp := path.Join(cacheDir(), "ollama.sock")
|
||||
|
||||
|
@ -94,7 +110,7 @@ func NewCLI() *cobra.Command {
|
|||
Short: "Run a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
return run(args[0])
|
||||
},
|
||||
}
|
||||
|
||||
|
|
140
server/models.go
Normal file
140
server/models.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// const directoryURL = "https://ollama.ai/api/models"
|
||||
const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
|
||||
|
||||
type directoryCtxKey string
|
||||
|
||||
var dirCtx directoryCtxKey = "directory"
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Parameters string `json:"parameters"`
|
||||
URL string `json:"url"`
|
||||
ShortDescription string `json:"short_description"`
|
||||
Description string `json:"description"`
|
||||
PublishedBy string `json:"published_by"`
|
||||
OriginalAuthor string `json:"original_author"`
|
||||
OriginalURL string `json:"original_url"`
|
||||
License string `json:"license"`
|
||||
}
|
||||
|
||||
func pull(model string, progressCh chan<- string) error {
|
||||
remote, err := getRemote(model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pull model: %w", err)
|
||||
}
|
||||
|
||||
return saveModel(remote, progressCh)
|
||||
}
|
||||
|
||||
func getRemote(model string) (*Model, error) {
|
||||
// resolve the model download from our directory
|
||||
resp, err := http.Get(directoryURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get directory: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
var models []Model
|
||||
err = json.Unmarshal(body, &models)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse directory: %w", err)
|
||||
}
|
||||
for _, m := range models {
|
||||
if m.Name == model {
|
||||
return &m, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("model not found in directory: %s", model)
|
||||
}
|
||||
|
||||
func saveModel(model *Model, progressCh chan<- string) error {
|
||||
// this models cache directory is created by the server on startup
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
modelsCache := path.Join(home, ".ollama", "models")
|
||||
|
||||
fileName := path.Join(modelsCache, model.Name+".bin")
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", model.URL, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// check for resume
|
||||
fileInfo, err := os.Stat(fileName)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to check resume model file: %w", err)
|
||||
}
|
||||
// file doesn't exist, create it now
|
||||
} else {
|
||||
req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download model: %w", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download model: %s", resp.Status)
|
||||
}
|
||||
|
||||
out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
totalBytes := 0
|
||||
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if _, err := out.Write(buf[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalBytes += n
|
||||
|
||||
// send progress updates
|
||||
progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
|
||||
}
|
||||
|
||||
// send completion message
|
||||
progressCh <- "Download complete!"
|
||||
|
||||
return nil
|
||||
}
|
|
@ -14,12 +14,6 @@ import (
|
|||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func pull(c *gin.Context) {
|
||||
// TODO
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func generate(c *gin.Context) {
|
||||
// TODO: these should be request parameters
|
||||
gpulayers := 1
|
||||
|
@ -65,7 +59,31 @@ func generate(c *gin.Context) {
|
|||
func Serve(ln net.Listener) error {
|
||||
r := gin.Default()
|
||||
|
||||
r.POST("api/pull", pull)
|
||||
r.POST("api/pull", func(c *gin.Context) {
|
||||
var req api.PullRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
progressCh := make(chan string)
|
||||
go func() {
|
||||
defer close(progressCh)
|
||||
if err := pull(req.Model, progressCh); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
progress, ok := <-progressCh
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
c.SSEvent("progress", progress)
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
r.POST("/api/generate", generate)
|
||||
|
||||
|
|
Loading…
Reference in a new issue