2023-07-06 12:24:49 -04:00
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
"path"
|
|
|
|
"strconv"
|
2023-07-06 14:18:40 -04:00
|
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
2023-07-06 12:24:49 -04:00
|
|
|
)
|
|
|
|
|
2023-07-07 15:13:41 -04:00
|
|
|
// directoryURL is the endpoint getRemote fetches and decodes as a JSON array
// of Model entries in order to resolve a model name to its metadata.
const (
	directoryURL = "https://ollama.ai/api/models"
)
|
2023-07-06 12:24:49 -04:00
|
|
|
|
|
|
|
// Model is one entry of the model directory served at directoryURL. Each
// field is decoded from the JSON key named in its struct tag; the field order
// is significant for JSON output order and must not change.
type Model struct {
	// Name is the key getRemote matches a requested model against, and the
	// base of the local file name built by FullName.
	Name string `json:"name"`
	// DisplayName is presumably the human-facing label for Name; it is not
	// interpreted in this file.
	DisplayName string `json:"display_name"`
	// Parameters is carried through from the directory as an opaque string.
	Parameters string `json:"parameters"`
	// URL is the location saveModel downloads the model file from.
	URL string `json:"url"`

	// The remaining fields are descriptive and provenance metadata that are
	// passed through as published and not interpreted in this file.
	ShortDescription string `json:"short_description"`
	Description      string `json:"description"`
	PublishedBy      string `json:"published_by"`
	OriginalAuthor   string `json:"original_author"`
	OriginalURL      string `json:"original_url"`
	License          string `json:"license"`
}
|
|
|
|
|
2023-07-06 15:43:04 -07:00
|
|
|
func (m *Model) FullName() string {
|
|
|
|
home, err := os.UserHomeDir()
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return path.Join(home, ".ollama", "models", m.Name+".bin")
|
|
|
|
}
|
|
|
|
|
2023-07-06 14:18:40 -04:00
|
|
|
func pull(model string, progressCh chan<- api.PullProgress) error {
|
2023-07-06 12:24:49 -04:00
|
|
|
remote, err := getRemote(model)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to pull model: %w", err)
|
|
|
|
}
|
|
|
|
return saveModel(remote, progressCh)
|
|
|
|
}
|
|
|
|
|
|
|
|
func getRemote(model string) (*Model, error) {
|
|
|
|
// resolve the model download from our directory
|
|
|
|
resp, err := http.Get(directoryURL)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to get directory: %w", err)
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
2023-07-06 15:43:04 -07:00
|
|
|
body, err := io.ReadAll(resp.Body)
|
2023-07-06 12:24:49 -04:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to read directory: %w", err)
|
|
|
|
}
|
|
|
|
var models []Model
|
|
|
|
err = json.Unmarshal(body, &models)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to parse directory: %w", err)
|
|
|
|
}
|
|
|
|
for _, m := range models {
|
|
|
|
if m.Name == model {
|
|
|
|
return &m, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil, fmt.Errorf("model not found in directory: %s", model)
|
|
|
|
}
|
|
|
|
|
2023-07-06 14:18:40 -04:00
|
|
|
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
2023-07-06 12:24:49 -04:00
|
|
|
// this models cache directory is created by the server on startup
|
|
|
|
|
|
|
|
client := &http.Client{}
|
|
|
|
req, err := http.NewRequest("GET", model.URL, nil)
|
|
|
|
if err != nil {
|
2023-07-06 15:03:52 -04:00
|
|
|
return fmt.Errorf("failed to download model: %w", err)
|
2023-07-06 12:24:49 -04:00
|
|
|
}
|
|
|
|
// check for resume
|
2023-07-06 14:05:55 -07:00
|
|
|
alreadyDownloaded := int64(0)
|
2023-07-06 15:43:04 -07:00
|
|
|
fileInfo, err := os.Stat(model.FullName())
|
2023-07-06 12:24:49 -04:00
|
|
|
if err != nil {
|
|
|
|
if !os.IsNotExist(err) {
|
|
|
|
return fmt.Errorf("failed to check resume model file: %w", err)
|
|
|
|
}
|
|
|
|
// file doesn't exist, create it now
|
|
|
|
} else {
|
2023-07-06 14:05:55 -07:00
|
|
|
alreadyDownloaded = fileInfo.Size()
|
|
|
|
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
|
2023-07-06 12:24:49 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to download model: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
2023-07-06 14:57:11 -04:00
|
|
|
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
|
|
|
// already downloaded
|
|
|
|
progressCh <- api.PullProgress{
|
|
|
|
Total: alreadyDownloaded,
|
|
|
|
Completed: alreadyDownloaded,
|
|
|
|
Percent: 100,
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
2023-07-06 12:24:49 -04:00
|
|
|
return fmt.Errorf("failed to download model: %s", resp.Status)
|
|
|
|
}
|
|
|
|
|
2023-07-06 15:43:04 -07:00
|
|
|
out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
2023-07-06 12:24:49 -04:00
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
defer out.Close()
|
|
|
|
|
2023-07-06 14:05:55 -07:00
|
|
|
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
2023-07-06 12:24:49 -04:00
|
|
|
|
|
|
|
buf := make([]byte, 1024)
|
2023-07-06 14:57:11 -04:00
|
|
|
totalBytes := alreadyDownloaded
|
|
|
|
totalSize += alreadyDownloaded
|
2023-07-06 12:24:49 -04:00
|
|
|
|
|
|
|
for {
|
|
|
|
n, err := resp.Body.Read(buf)
|
|
|
|
if err != nil && err != io.EOF {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if n == 0 {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
if _, err := out.Write(buf[:n]); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-07-06 14:05:55 -07:00
|
|
|
|
|
|
|
totalBytes += int64(n)
|
2023-07-06 12:24:49 -04:00
|
|
|
|
|
|
|
// send progress updates
|
2023-07-06 14:18:40 -04:00
|
|
|
progressCh <- api.PullProgress{
|
|
|
|
Total: totalSize,
|
|
|
|
Completed: totalBytes,
|
|
|
|
Percent: float64(totalBytes) / float64(totalSize) * 100,
|
|
|
|
}
|
2023-07-06 12:24:49 -04:00
|
|
|
}
|
|
|
|
|
2023-07-06 14:18:40 -04:00
|
|
|
progressCh <- api.PullProgress{
|
|
|
|
Total: totalSize,
|
|
|
|
Completed: totalSize,
|
|
|
|
Percent: 100,
|
|
|
|
}
|
2023-07-06 12:24:49 -04:00
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|