pull models
This commit is contained in:
parent
0833f5af3a
commit
a6494f8211
5 changed files with 202 additions and 9 deletions
|
@ -140,3 +140,14 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) {
|
||||||
|
var res PullResponse
|
||||||
|
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) {
|
||||||
|
callback(string(token))
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &res, nil
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,14 @@ func (e Error) Error() string {
|
||||||
return e.Message
|
return e.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PullRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PullResponse struct {
|
||||||
|
Response string `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
18
cmd/cmd.go
18
cmd/cmd.go
|
@ -2,6 +2,7 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -23,6 +24,21 @@ func cacheDir() string {
|
||||||
return path.Join(home, ".ollama")
|
return path.Join(home, ".ollama")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func run(model string) error {
|
||||||
|
client, err := NewAPIClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pr := api.PullRequest{
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
callback := func(progress string) {
|
||||||
|
fmt.Println(progress)
|
||||||
|
}
|
||||||
|
_, err = client.Pull(context.Background(), &pr, callback)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func serve() error {
|
func serve() error {
|
||||||
sp := path.Join(cacheDir(), "ollama.sock")
|
sp := path.Join(cacheDir(), "ollama.sock")
|
||||||
|
|
||||||
|
@ -94,7 +110,7 @@ func NewCLI() *cobra.Command {
|
||||||
Short: "Run a model",
|
Short: "Run a model",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
return nil
|
return run(args[0])
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
140
server/models.go
Normal file
140
server/models.go
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// const directoryURL = "https://ollama.ai/api/models"
|
||||||
|
const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
|
||||||
|
|
||||||
|
type directoryCtxKey string
|
||||||
|
|
||||||
|
var dirCtx directoryCtxKey = "directory"
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
Parameters string `json:"parameters"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
ShortDescription string `json:"short_description"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
PublishedBy string `json:"published_by"`
|
||||||
|
OriginalAuthor string `json:"original_author"`
|
||||||
|
OriginalURL string `json:"original_url"`
|
||||||
|
License string `json:"license"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func pull(model string, progressCh chan<- string) error {
|
||||||
|
remote, err := getRemote(model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to pull model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return saveModel(remote, progressCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRemote(model string) (*Model, error) {
|
||||||
|
// resolve the model download from our directory
|
||||||
|
resp, err := http.Get(directoryURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get directory: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||||
|
}
|
||||||
|
var models []Model
|
||||||
|
err = json.Unmarshal(body, &models)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse directory: %w", err)
|
||||||
|
}
|
||||||
|
for _, m := range models {
|
||||||
|
if m.Name == model {
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("model not found in directory: %s", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveModel(model *Model, progressCh chan<- string) error {
|
||||||
|
// this models cache directory is created by the server on startup
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
modelsCache := path.Join(home, ".ollama", "models")
|
||||||
|
|
||||||
|
fileName := path.Join(modelsCache, model.Name+".bin")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest("GET", model.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
// check for resume
|
||||||
|
fileInfo, err := os.Stat(fileName)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to check resume model file: %w", err)
|
||||||
|
}
|
||||||
|
// file doesn't exist, create it now
|
||||||
|
} else {
|
||||||
|
req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to download model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("failed to download model: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
totalBytes := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := resp.Body.Read(buf)
|
||||||
|
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := out.Write(buf[:n]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
totalBytes += n
|
||||||
|
|
||||||
|
// send progress updates
|
||||||
|
progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// send completion message
|
||||||
|
progressCh <- "Download complete!"
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -14,12 +14,6 @@ import (
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func pull(c *gin.Context) {
|
|
||||||
// TODO
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func generate(c *gin.Context) {
|
func generate(c *gin.Context) {
|
||||||
// TODO: these should be request parameters
|
// TODO: these should be request parameters
|
||||||
gpulayers := 1
|
gpulayers := 1
|
||||||
|
@ -65,7 +59,31 @@ func generate(c *gin.Context) {
|
||||||
func Serve(ln net.Listener) error {
|
func Serve(ln net.Listener) error {
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
|
|
||||||
r.POST("api/pull", pull)
|
r.POST("api/pull", func(c *gin.Context) {
|
||||||
|
var req api.PullRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
progressCh := make(chan string)
|
||||||
|
go func() {
|
||||||
|
defer close(progressCh)
|
||||||
|
if err := pull(req.Model, progressCh); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
progress, ok := <-progressCh
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.SSEvent("progress", progress)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
r.POST("/api/generate", generate)
|
r.POST("/api/generate", generate)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue