ollama/server/models.go

141 lines
3.2 KiB
Go
Raw Normal View History

2023-07-06 12:24:49 -04:00
package server
import (
"encoding/json"
2023-07-11 13:36:35 -07:00
"errors"
2023-07-06 12:24:49 -04:00
"fmt"
"io"
"net/http"
"os"
"path"
"strconv"
)
2023-07-07 15:13:41 -04:00
const directoryURL = "https://ollama.ai/api/models"
2023-07-06 12:24:49 -04:00
// Model is one entry of the remote model directory. Values are decoded from
// the directory's JSON listing using the field tags below.
type Model struct {
	// Name identifies the model: it is the key matched during directory
	// lookup and the stem of the model's on-disk file name.
	Name string `json:"name"`

	DisplayName string `json:"display_name"`
	Parameters  string `json:"parameters"`

	// URL is the location the model file is downloaded from.
	URL string `json:"url"`

	// The remaining fields are descriptive metadata carried through from the
	// directory listing; nothing in this file interprets them.
	ShortDescription string `json:"short_description"`
	Description      string `json:"description"`
	PublishedBy      string `json:"published_by"`
	OriginalAuthor   string `json:"original_author"`
	OriginalURL      string `json:"original_url"`
	License          string `json:"license"`
}
2023-07-06 15:43:04 -07:00
// FullName returns the path of the model's completed file:
// <home>/.ollama/models/<Name>.bin, where <home> is the current user's home
// directory. It panics if the home directory cannot be determined.
func (m *Model) FullName() string {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}

	fileName := m.Name + ".bin"
	return path.Join(homeDir, ".ollama", "models", fileName)
}
2023-07-11 13:36:35 -07:00
// TempFile returns the path used while the model is still being downloaded.
// It lives in the same directory as FullName and is named after it with a
// leading "." and a trailing ".part".
func (m *Model) TempFile() string {
	target := m.FullName()
	dir := path.Dir(target)
	partName := "." + path.Base(target) + ".part"
	return path.Join(dir, partName)
}
2023-07-06 12:24:49 -04:00
func getRemote(model string) (*Model, error) {
// resolve the model download from our directory
resp, err := http.Get(directoryURL)
if err != nil {
return nil, fmt.Errorf("failed to get directory: %w", err)
}
defer resp.Body.Close()
2023-07-06 15:43:04 -07:00
body, err := io.ReadAll(resp.Body)
2023-07-06 12:24:49 -04:00
if err != nil {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
var models []Model
err = json.Unmarshal(body, &models)
if err != nil {
return nil, fmt.Errorf("failed to parse directory: %w", err)
}
for _, m := range models {
if m.Name == model {
return &m, nil
}
}
return nil, fmt.Errorf("model not found in directory: %s", model)
}
2023-07-11 11:54:22 -07:00
func saveModel(model *Model, fn func(total, completed int64)) error {
2023-07-06 12:24:49 -04:00
// this models cache directory is created by the server on startup
client := &http.Client{}
req, err := http.NewRequest("GET", model.URL, nil)
if err != nil {
2023-07-06 15:03:52 -04:00
return fmt.Errorf("failed to download model: %w", err)
2023-07-06 12:24:49 -04:00
}
2023-07-11 13:36:35 -07:00
// check if completed file exists
fi, err := os.Stat(model.FullName())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
fn(fi.Size(), fi.Size())
return nil
}
var size int64
// completed file doesn't exist, check partial file
fi, err = os.Stat(model.TempFile())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
2023-07-06 12:24:49 -04:00
}
2023-07-11 13:36:35 -07:00
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
2023-07-06 12:24:49 -04:00
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
defer resp.Body.Close()
2023-07-11 13:36:35 -07:00
if resp.StatusCode >= 400 {
2023-07-06 12:24:49 -04:00
return fmt.Errorf("failed to download model: %s", resp.Status)
}
2023-07-11 13:36:35 -07:00
out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
2023-07-06 12:24:49 -04:00
if err != nil {
panic(err)
}
defer out.Close()
2023-07-12 09:55:07 -07:00
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
completed := size
2023-07-06 12:24:49 -04:00
2023-07-12 09:55:07 -07:00
total := remaining + completed
2023-07-06 12:24:49 -04:00
for {
2023-07-12 09:55:07 -07:00
fn(total, completed)
if completed >= total {
return os.Rename(model.TempFile(), model.FullName())
2023-07-06 12:24:49 -04:00
}
2023-07-11 13:36:35 -07:00
2023-07-12 09:55:07 -07:00
n , err := io.CopyN(out, resp.Body, 8192)
if err != nil && !errors.Is(err, io.EOF) {
return err
2023-07-06 12:24:49 -04:00
}
2023-07-12 09:55:07 -07:00
completed += n
2023-07-06 14:18:40 -04:00
}
2023-07-06 12:24:49 -04:00
}