ollama/server/images.go

1337 lines
31 KiB
Go
Raw Normal View History

package server
import (
"archive/zip"
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
2023-08-29 03:50:24 +00:00
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"log/slog"
"net/http"
2023-08-22 01:38:31 +00:00
"net/url"
"os"
"path/filepath"
2023-08-22 01:24:42 +00:00
"runtime"
2024-02-14 19:29:49 +00:00
"strconv"
"strings"
2023-11-29 19:11:42 +00:00
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/convert"
2024-03-13 18:03:56 +00:00
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes"
2024-04-16 23:22:38 +00:00
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
2024-02-14 19:29:49 +00:00
// registryOptions carries per-request settings for talking to a registry.
// Insecure permits plain-http registries (PushModel/PullModel reject http
// otherwise, and makeRequest forces the scheme to http when it is set).
// makeRequest sends Token as a bearer token, falling back to Username and
// Password as basic auth when no token is present and both are non-empty.
type registryOptions struct {
	Insecure bool

	Username, Password, Token string
}
type Model struct {
2023-11-30 18:30:23 +00:00
Name string `json:"name"`
2023-12-01 19:37:17 +00:00
Config ConfigV2
2023-11-30 18:30:23 +00:00
ShortName string
ModelPath string
2024-01-25 20:12:36 +00:00
ParentModel string
2023-11-30 18:30:23 +00:00
AdapterPaths []string
ProjectorPaths []string
Template string
System string
License []string
Digest string
Size int64
2023-11-30 18:30:23 +00:00
Options map[string]interface{}
2024-01-25 20:12:36 +00:00
Messages []Message
}
// IsEmbedding reports whether the model's config lists "bert" or "nomic-bert"
// among its model families.
func (m *Model) IsEmbedding() bool {
	for _, family := range []string{"bert", "nomic-bert"} {
		if slices.Contains(m.Config.ModelFamilies, family) {
			return true
		}
	}
	return false
}
2024-04-30 17:55:19 +00:00
// String renders the model as Modelfile text. It emits, in order, a "model"
// command for ModelPath, then template and system (when non-empty), adapter
// and projector commands, one command per option (list-valued options expand
// to one command per element), license commands and message commands.
// Option names are emitted in sorted order so the output is deterministic;
// Go map iteration order is randomized.
func (m *Model) String() string {
	var modelfile model.File

	modelfile.Commands = append(modelfile.Commands, model.Command{
		Name: "model",
		Args: m.ModelPath,
	})

	if m.Template != "" {
		modelfile.Commands = append(modelfile.Commands, model.Command{
			Name: "template",
			Args: m.Template,
		})
	}

	if m.System != "" {
		modelfile.Commands = append(modelfile.Commands, model.Command{
			Name: "system",
			Args: m.System,
		})
	}

	for _, adapter := range m.AdapterPaths {
		modelfile.Commands = append(modelfile.Commands, model.Command{
			Name: "adapter",
			Args: adapter,
		})
	}

	for _, projector := range m.ProjectorPaths {
		modelfile.Commands = append(modelfile.Commands, model.Command{
			Name: "projector",
			Args: projector,
		})
	}

	keys := make([]string, 0, len(m.Options))
	for k := range m.Options {
		keys = append(keys, k)
	}
	slices.Sort(keys)

	for _, k := range keys {
		switch v := m.Options[k].(type) {
		case []any:
			for _, s := range v {
				modelfile.Commands = append(modelfile.Commands, model.Command{
					Name: k,
					Args: fmt.Sprintf("%v", s),
				})
			}
		default:
			modelfile.Commands = append(modelfile.Commands, model.Command{
				Name: k,
				Args: fmt.Sprintf("%v", v),
			})
		}
	}

	for _, license := range m.License {
		modelfile.Commands = append(modelfile.Commands, model.Command{
			Name: "license",
			Args: license,
		})
	}

	for _, msg := range m.Messages {
		modelfile.Commands = append(modelfile.Commands, model.Command{
			Name: "message",
			Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
		})
	}

	return modelfile.String()
}
2024-01-25 20:12:36 +00:00
// Message is one conversation turn. GetModel decodes a model's messages layer
// into a []Message, and Model.String renders each one as a "message" command.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// ManifestV2 lists the blobs that make up a model: one config blob plus an
// ordered list of layers. It is written to the local manifests directory by
// PullModel and exchanged with registries using the media type
// "application/vnd.docker.distribution.manifest.v2+json".
// Field order and JSON tags determine the serialized form.
type ManifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        *Layer   `json:"config"`
	Layers        []*Layer `json:"layers"`
}
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
2023-07-21 20:33:56 +00:00
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`
2023-12-01 19:37:17 +00:00
RootFS RootFS `json:"rootfs"`
}
2023-11-29 19:11:42 +00:00
// SetModelFormat sets ModelFormat to format only if it has not been set yet,
// so the first value recorded wins.
func (c *ConfigV2) SetModelFormat(format string) {
	if c.ModelFormat != "" {
		return
	}
	c.ModelFormat = format
}
// SetModelFamily appends each of families to ModelFamilies, skipping ones that
// are already present. ModelFamily is set to the first family given if it is
// still empty.
func (c *ConfigV2) SetModelFamily(families ...string) {
	for _, f := range families {
		if len(c.ModelFamily) == 0 {
			c.ModelFamily = f
		}

		if slices.Contains(c.ModelFamilies, f) {
			continue
		}
		c.ModelFamilies = append(c.ModelFamilies, f)
	}
}
// SetModelType sets ModelType to modelType only if it has not been set yet.
func (c *ConfigV2) SetModelType(modelType string) {
	if c.ModelType != "" {
		return
	}
	c.ModelType = modelType
}
// SetFileType sets FileType to fileType only if it has not been set yet.
func (c *ConfigV2) SetFileType(fileType string) {
	if c.FileType != "" {
		return
	}
	c.FileType = fileType
}
// RootFS is the "rootfs" section of ConfigV2. CreateModel sets Type to
// "layers" and DiffIDs to the digests of the model's layers, in order.
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}
2023-09-28 17:00:34 +00:00
func (m *ManifestV2) GetTotalSize() (total int64) {
2023-07-18 16:09:45 +00:00
for _, layer := range m.Layers {
total += layer.Size
}
2023-09-28 17:00:34 +00:00
2023-07-18 16:09:45 +00:00
total += m.Config.Size
return total
}
2023-08-29 03:50:24 +00:00
// GetManifest reads and decodes the manifest stored locally for mp.
//
// It returns the manifest, the hex-encoded sha256 of the raw manifest file
// and an error. A missing manifest file yields an error satisfying
// errors.Is(err, os.ErrNotExist). A file that decodes to no manifest (a JSON
// null) is reported as an error, so callers never receive a nil manifest
// with a nil error.
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
	fp, err := mp.GetManifestPath()
	if err != nil {
		return nil, "", err
	}

	if _, err = os.Stat(fp); err != nil {
		return nil, "", err
	}

	var manifest *ManifestV2
	bts, err := os.ReadFile(fp)
	if err != nil {
		// wrap so the underlying cause (and errors.Is checks) survive
		return nil, "", fmt.Errorf("couldn't open file '%s': %w", fp, err)
	}

	shaSum := sha256.Sum256(bts)
	shaStr := hex.EncodeToString(shaSum[:])

	if err := json.Unmarshal(bts, &manifest); err != nil {
		return nil, "", err
	}

	if manifest == nil {
		return nil, "", fmt.Errorf("manifest file '%s' is empty", fp)
	}

	return manifest, shaStr, nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
2023-08-29 03:50:24 +00:00
manifest, digest, err := GetManifest(mp)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Size: manifest.GetTotalSize(),
}
2023-12-01 19:37:17 +00:00
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
for _, layer := range manifest.Layers {
2023-07-18 05:44:21 +00:00
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
2024-01-25 20:12:36 +00:00
model.ParentModel = layer.From
2023-08-04 22:56:40 +00:00
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
2023-11-30 18:30:23 +00:00
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.System = string(bts)
case "application/vnd.ollama.image.prompt":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.params":
2023-07-17 19:08:10 +00:00
params, err := os.Open(filename)
if err != nil {
return nil, err
}
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
2024-01-25 20:12:36 +00:00
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
2023-09-06 18:04:17 +00:00
case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.License = append(model.License, string(bts))
}
}
return model, nil
}
2023-11-21 20:43:17 +00:00
// realpath resolves a path named in a Modelfile.
//
// The rules are applied in order:
//   - if from cannot be made absolute, it is returned unchanged;
//   - if the home directory cannot be determined, the absolute path is
//     returned;
//   - "~" and "~/..." expand against the user's home directory;
//   - if from exists relative to the Modelfile's directory mfDir, that path
//     is returned;
//   - otherwise from is returned as an absolute path relative to the
//     current working directory.
func realpath(mfDir, from string) string {
	abs, err := filepath.Abs(from)
	if err != nil {
		return from
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return abs
	}

	switch {
	case from == "~":
		return home
	case strings.HasPrefix(from, "~/"):
		return filepath.Join(home, from[2:])
	}

	candidate := filepath.Join(mfDir, from)
	if _, statErr := os.Stat(candidate); statErr == nil {
		// the path names a file next to the Modelfile
		return candidate
	}

	return abs
}
2024-04-30 17:55:19 +00:00
// CreateModel builds the model name from the parsed Modelfile commands and
// writes its layers, config and manifest to the local store. Progress is
// reported through fn.
//
// The "model" (FROM) command accepts any of:
//   - a blob reference ("@<digest>");
//   - a zip archive that convertModel turns into GGUF, optionally quantized
//     to quantization;
//   - a GGUF file on disk, resolved with realpath relative to modelFileDir;
//   - the name of an existing model, which is pulled when it is not present
//     locally. Its layers and params are inherited.
//
// Other commands:
//   - adapter, license, template and system each contribute a layer;
//   - message arguments must have the form "role: content";
//   - any other command is collected as a parameter and merged with
//     parameters inherited from a base model.
//
// Afterwards, unless envconfig.NoPrune is set, candidate blobs (from the
// manifest being replaced and any base model) that no manifest still
// references are removed.
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
	deleteMap := make(map[string]struct{})
	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
		for _, layer := range append(manifest.Layers, manifest.Config) {
			deleteMap[layer.Digest] = struct{}{}
		}
	}

	config := ConfigV2{
		OS:           "linux",
		Architecture: "amd64",
		RootFS: RootFS{
			Type: "layers",
		},
	}

	var layers Layers
	messages := []string{}

	params := make(map[string][]string)
	fromParams := make(map[string]any)

	for _, c := range modelfile.Commands {
		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

		switch c.Name {
		case "model":
			if strings.HasPrefix(c.Args, "@") {
				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
				if err != nil {
					return err
				}

				c.Args = blobPath
			}

			pathName := realpath(modelFileDir, c.Args)

			ggufName, err := convertModel(name, pathName, fn)
			if err != nil {
				var pathErr *fs.PathError
				switch {
				case errors.Is(err, zip.ErrFormat):
					// it's not a safetensor archive
				case errors.As(err, &pathErr):
					// it's not a file on disk, could be a model reference
				default:
					return err
				}
			}

			if ggufName != "" {
				pathName = ggufName
				defer os.RemoveAll(ggufName)

				if quantization != "" {
					quantization = strings.ToUpper(quantization)
					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
					if err != nil {
						return err
					}
					defer os.RemoveAll(tempfile.Name())

					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
						// release the handle so the deferred removal can succeed
						tempfile.Close()
						return err
					}

					if err := tempfile.Close(); err != nil {
						return err
					}

					pathName = tempfile.Name()
				}
			}

			bin, err := os.Open(pathName)
			if err != nil {
				// not a file on disk so must be a model reference
				modelpath := ParseModelPath(c.Args)
				manifest, _, err := GetManifest(modelpath)
				switch {
				case errors.Is(err, os.ErrNotExist):
					fn(api.ProgressResponse{Status: "pulling model"})
					if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
						return err
					}

					manifest, _, err = GetManifest(modelpath)
					if err != nil {
						return err
					}
				case err != nil:
					return err
				}

				fn(api.ProgressResponse{Status: "reading model metadata"})
				fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
				if err != nil {
					return err
				}

				fromConfigFile, err := os.Open(fromConfigPath)
				if err != nil {
					return err
				}
				defer fromConfigFile.Close()

				var fromConfig ConfigV2
				if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
					return err
				}

				// if the model is still not in gguf format, error out
				if fromConfig.ModelFormat != "gguf" {
					return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
				}

				config.SetModelFormat(fromConfig.ModelFormat)
				config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
				config.SetModelType(fromConfig.ModelType)
				config.SetFileType(fromConfig.FileType)

				for _, layer := range manifest.Layers {
					deleteMap[layer.Digest] = struct{}{}
					if layer.MediaType == "application/vnd.ollama.image.params" {
						fromParamsPath, err := GetBlobsPath(layer.Digest)
						if err != nil {
							return err
						}

						fromParamsFile, err := os.Open(fromParamsPath)
						if err != nil {
							return err
						}
						defer fromParamsFile.Close()

						if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
							return err
						}
					}

					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
					if err != nil {
						return err
					}

					layers.Add(layer)
				}

				deleteMap[manifest.Config.Digest] = struct{}{}
				continue
			}
			defer bin.Close()

			// a file may hold several concatenated GGUF models (e.g. a projector); emit one layer each
			var offset int64
			for {
				fn(api.ProgressResponse{Status: "creating model layer"})
				if _, err := bin.Seek(offset, io.SeekStart); err != nil {
					return err
				}

				ggml, size, err := llm.DecodeGGML(bin)
				if errors.Is(err, io.EOF) {
					break
				} else if errors.Is(err, llm.ErrUnsupportedFormat) {
					return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
				} else if err != nil {
					return err
				}

				config.SetModelFormat(ggml.Name())
				config.SetModelFamily(ggml.KV().Architecture())
				config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
				config.SetFileType(ggml.KV().FileType())

				mediatype := mediatype
				if ggml.KV().Architecture() == "clip" {
					mediatype = "application/vnd.ollama.image.projector"
				}

				sr := io.NewSectionReader(bin, offset, size)
				layer, err := NewLayer(sr, mediatype)
				if err != nil {
					return err
				}

				layers.Add(layer)

				offset += size
			}
		case "adapter":
			if strings.HasPrefix(c.Args, "@") {
				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
				if err != nil {
					return err
				}

				c.Args = blobPath
			}

			fn(api.ProgressResponse{Status: "creating adapter layer"})
			bin, err := os.Open(realpath(modelFileDir, c.Args))
			if err != nil {
				return err
			}
			defer bin.Close()

			_, size, err := llm.DecodeGGML(bin)
			if err != nil {
				return err
			}

			sr := io.NewSectionReader(bin, 0, size)
			layer, err := NewLayer(sr, mediatype)
			if err != nil {
				return err
			}

			layers.Add(layer)
		case "license":
			fn(api.ProgressResponse{Status: "creating license layer"})

			bin := strings.NewReader(c.Args)
			layer, err := NewLayer(bin, mediatype)
			if err != nil {
				return err
			}

			layers.Add(layer)
		case "template", "system":
			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})

			bin := strings.NewReader(c.Args)
			layer, err := NewLayer(bin, mediatype)
			if err != nil {
				return err
			}

			layers.Replace(layer)
		case "message":
			messages = append(messages, c.Args)
		default:
			params[c.Name] = append(params[c.Name], c.Args)
		}
	}

	if len(messages) > 0 {
		fn(api.ProgressResponse{Status: "creating parameters layer"})

		msgs := make([]api.Message, 0)
		for _, m := range messages {
			// todo: handle images
			role, content, ok := strings.Cut(m, ": ")
			if !ok {
				// previously indexed past the end of the split result and panicked
				return fmt.Errorf("invalid message %q: expected \"role: content\"", m)
			}
			msgs = append(msgs, api.Message{Role: role, Content: content})
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
			return err
		}

		layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
		if err != nil {
			return err
		}

		layers.Replace(layer)
	}

	if len(params) > 0 {
		fn(api.ProgressResponse{Status: "creating parameters layer"})

		formattedParams, err := api.FormatParams(params)
		if err != nil {
			return err
		}

		// parameters set in this Modelfile take precedence over inherited ones
		for k, v := range fromParams {
			if _, ok := formattedParams[k]; !ok {
				formattedParams[k] = v
			}
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
			return err
		}

		fn(api.ProgressResponse{Status: "creating config layer"})
		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
		if err != nil {
			return err
		}

		layers.Replace(layer)
	}

	digests := make([]string, len(layers.items))
	for i, layer := range layers.items {
		digests[i] = layer.Digest
	}

	config.RootFS.DiffIDs = digests

	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(config); err != nil {
		return err
	}

	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		return err
	}

	delete(deleteMap, configLayer.Digest)

	for _, layer := range append(layers.items, configLayer) {
		committed, err := layer.Commit()
		if err != nil {
			return err
		}

		status := "writing layer"
		if !committed {
			status = "using already created layer"
		}

		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})

		delete(deleteMap, layer.Digest)
	}

	fn(api.ProgressResponse{Status: "writing manifest"})
	if err := WriteManifest(name, configLayer, layers.items); err != nil {
		return err
	}

	if !envconfig.NoPrune {
		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
			return err
		}
	}

	fn(api.ProgressResponse{Status: "success"})
	return nil
}
// convertModel treats path as a zip archive of a non-GGUF model, unpacks it
// into a temporary directory, converts it with the convert package and
// returns the path of the GGUF file produced by WriteGGUF. Progress is
// reported through fn.
//
// If path cannot be opened as a zip archive, the error from zip.OpenReader is
// returned unchanged (zip.ErrFormat, or an *fs.PathError). CreateModel relies
// on this to fall back to treating FROM as a GGUF file or a model name.
func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
	r, err := zip.OpenReader(path)
	if err != nil {
		return "", err
	}
	defer r.Close()

	tempDir, err := os.MkdirTemp("", "ollama-convert")
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(tempDir)

	fn(api.ProgressResponse{Status: "unpacking model metadata"})
	for _, f := range r.File {
		if err := unpackZipEntry(tempDir, f); err != nil {
			return "", err
		}
	}

	mf, err := convert.GetModelFormat(tempDir)
	if err != nil {
		return "", err
	}

	params, err := mf.GetParams(tempDir)
	if err != nil {
		return "", err
	}

	mArch, err := mf.GetModelArch(name, tempDir, params)
	if err != nil {
		return "", err
	}

	fn(api.ProgressResponse{Status: "processing tensors"})
	if err := mArch.GetTensors(); err != nil {
		return "", err
	}

	if err := mArch.LoadVocab(); err != nil {
		return "", err
	}

	fn(api.ProgressResponse{Status: "converting model"})
	path, err = mArch.WriteGGUF()
	if err != nil {
		return "", err
	}

	return path, nil
}

// unpackZipEntry extracts one archive entry beneath dir.
//
// Names that are absolute, empty or would escape dir via ".." are rejected
// (zip-slip). Directory entries and missing parent directories are created.
// Both the entry reader and the output file are closed on every path.
func unpackZipEntry(dir string, f *zip.File) error {
	if !filepath.IsLocal(f.Name) {
		return fmt.Errorf("invalid file name in archive: %q", f.Name)
	}

	fpath := filepath.Join(dir, f.Name)
	if f.FileInfo().IsDir() {
		return os.MkdirAll(fpath, 0o755)
	}

	if err := os.MkdirAll(filepath.Dir(fpath), 0o755); err != nil {
		return err
	}

	rc, err := f.Open()
	if err != nil {
		return err
	}
	defer rc.Close()

	outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
	if err != nil {
		return err
	}

	if _, err := io.Copy(outFile, rc); err != nil {
		outFile.Close()
		return err
	}

	return outFile.Close()
}
2024-04-16 23:22:38 +00:00
func CopyModel(src, dst model.Name) error {
if !dst.IsFullyQualified() {
return model.Unqualified(dst)
}
if !src.IsFullyQualified() {
return model.Unqualified(src)
}
2024-04-29 03:47:49 +00:00
if src.Filepath() == dst.Filepath() {
return nil
}
2024-04-16 23:22:38 +00:00
manifests, err := GetManifestPath()
2023-08-22 04:56:56 +00:00
if err != nil {
return err
}
dstpath := filepath.Join(manifests, dst.Filepath())
2024-04-16 23:22:38 +00:00
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
2023-07-24 15:27:28 +00:00
srcpath := filepath.Join(manifests, src.Filepath())
2024-04-16 23:22:38 +00:00
srcfile, err := os.Open(srcpath)
2023-07-24 15:27:28 +00:00
if err != nil {
return err
}
2024-04-16 23:22:38 +00:00
defer srcfile.Close()
2023-07-24 15:27:28 +00:00
2024-04-16 23:22:38 +00:00
dstfile, err := os.Create(dstpath)
2023-07-24 15:27:28 +00:00
if err != nil {
return err
}
2024-04-16 23:22:38 +00:00
defer dstfile.Close()
2023-07-24 15:27:28 +00:00
2024-04-16 23:22:38 +00:00
_, err = io.Copy(dstfile, srcfile)
return err
2023-07-24 15:27:28 +00:00
}
2023-11-14 20:30:34 +00:00
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
2023-07-20 23:09:23 +00:00
fp, err := GetManifestPath()
if err != nil {
return err
}
2023-08-30 18:31:12 +00:00
walkFunc := func(path string, info os.FileInfo, _ error) error {
if info.IsDir() {
return nil
2023-07-20 23:09:23 +00:00
}
2023-08-30 18:31:12 +00:00
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
fmp := ParseModelPath(tag)
2023-07-20 23:09:23 +00:00
2023-08-30 18:31:12 +00:00
// skip the manifest we're trying to delete
if skipModelPath != nil && skipModelPath.GetFullTagname() == fmp.GetFullTagname() {
2023-08-30 18:31:12 +00:00
return nil
2023-07-20 23:09:23 +00:00
}
2023-08-30 18:31:12 +00:00
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
if err != nil {
2023-12-15 22:07:34 +00:00
// nolint: nilerr
2023-08-30 18:31:12 +00:00
return nil
}
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
2023-07-20 23:09:23 +00:00
return nil
2023-08-30 18:31:12 +00:00
}
if err := filepath.Walk(fp, walkFunc); err != nil {
2023-07-31 22:26:18 +00:00
return err
}
2023-07-20 23:09:23 +00:00
// only delete the files which are still in the deleteMap
2023-11-14 20:30:34 +00:00
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
2023-11-14 20:30:34 +00:00
continue
}
if !dryRun {
if err := os.Remove(fp); err != nil {
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
continue
}
2023-11-14 20:30:34 +00:00
} else {
slog.Info(fmt.Sprintf("wanted to remove: %s", fp))
2023-07-20 23:09:23 +00:00
}
}
return nil
}
// PruneLayers removes every blob in the blobs directory whose name maps to a
// "sha256:" digest and that no manifest references. The totals before and
// after are logged.
func PruneLayers() error {
	blobsDir, err := GetBlobsPath("")
	if err != nil {
		return err
	}

	entries, err := os.ReadDir(blobsDir)
	if err != nil {
		slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", blobsDir, err))
		return err
	}

	candidates := make(map[string]struct{})
	for _, entry := range entries {
		// blob file names use "-" where digests use ":"
		digest := strings.ReplaceAll(entry.Name(), "-", ":")
		if strings.HasPrefix(digest, "sha256:") {
			candidates[digest] = struct{}{}
		}
	}

	slog.Info(fmt.Sprintf("total blobs: %d", len(candidates)))

	if err := deleteUnusedLayers(nil, candidates, false); err != nil {
		return err
	}

	// deleteUnusedLayers drops referenced digests from the map, leaving only the removed ones
	slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(candidates)))

	return nil
}
2023-09-27 00:28:14 +00:00
// PruneDirectory recursively removes empty directories rooted at path,
// including path itself if it ends up empty.
//
// Regular files and symlinks are never removed or followed. A directory that
// still has any entry after its children are pruned is kept.
func PruneDirectory(path string) error {
	fi, err := os.Lstat(path)
	if err != nil {
		return err
	}

	// only real directories are descended into; symlinks are left alone
	if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	for _, child := range children {
		if err := PruneDirectory(filepath.Join(path, child.Name())); err != nil {
			return err
		}
	}

	// re-read: pruning the children may have emptied this directory
	remaining, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	if len(remaining) != 0 {
		return nil
	}

	return os.Remove(path)
}
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
2023-11-14 20:30:34 +00:00
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
2023-11-14 20:30:34 +00:00
deleteMap[layer.Digest] = struct{}{}
}
2023-11-14 20:30:34 +00:00
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap, false)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
2023-07-20 23:09:23 +00:00
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
2023-07-20 23:09:23 +00:00
return err
}
return nil
}
2024-02-14 19:29:49 +00:00
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
2023-07-19 01:51:30 +00:00
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
2023-08-29 03:50:24 +00:00
manifest, _, err := GetManifest(mp)
if err != nil {
2023-07-19 01:51:30 +00:00
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []*Layer
2023-08-01 01:37:40 +00:00
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
for _, layer := range layers {
2023-10-09 17:24:27 +00:00
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
2023-07-19 01:51:30 +00:00
}
fn(api.ProgressResponse{Status: "pushing manifest"})
2023-08-22 01:38:31 +00:00
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
2023-08-22 01:24:42 +00:00
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
2023-11-02 20:10:58 +00:00
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
defer resp.Body.Close()
fn(api.ProgressResponse{Status: "success"})
return nil
}
2024-02-14 19:29:49 +00:00
// PullModel downloads the model name from its registry and stores it locally.
//
// It fetches the manifest, downloads and sha256-verifies every layer and the
// config blob, and writes the manifest locally. Blobs that fail verification
// are deleted. Unless envconfig.NoPrune is set, blobs of the previously stored
// manifest for name that no manifest references any more are removed
// afterwards. Plain-http registries are rejected unless regOpts.Insecure is
// set. Progress is reported through fn.
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
	mp := ParseModelPath(name)

	var manifest *ManifestV2
	var err error

	// build deleteMap to prune unused layers
	deleteMap := make(map[string]struct{})

	if !envconfig.NoPrune {
		manifest, _, err = GetManifest(mp)
		if err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}

		if manifest != nil {
			for _, l := range manifest.Layers {
				deleteMap[l.Digest] = struct{}{}
			}
			deleteMap[manifest.Config.Digest] = struct{}{}
		}
	}

	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
		return fmt.Errorf("insecure protocol http")
	}

	fn(api.ProgressResponse{Status: "pulling manifest"})

	manifest, err = pullModelManifest(ctx, mp, regOpts)
	if err != nil {
		// %w keeps os.ErrNotExist and typed auth errors visible to callers
		return fmt.Errorf("pull model manifest: %w", err)
	}

	var layers []*Layer
	layers = append(layers, manifest.Layers...)
	layers = append(layers, manifest.Config)

	for _, layer := range layers {
		if err := downloadBlob(
			ctx,
			downloadOpts{
				mp:      mp,
				digest:  layer.Digest,
				regOpts: regOpts,
				fn:      fn,
			}); err != nil {
			return err
		}
		delete(deleteMap, layer.Digest)
	}
	delete(deleteMap, manifest.Config.Digest)

	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
	for _, layer := range layers {
		if err := verifyBlob(layer.Digest); err != nil {
			if errors.Is(err, errDigestMismatch) {
				// something went wrong, delete the blob
				fp, err := GetBlobsPath(layer.Digest)
				if err != nil {
					return err
				}
				if err := os.Remove(fp); err != nil {
					// log this, but return the original error
					slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
				}
			}
			return err
		}
	}

	fn(api.ProgressResponse{Status: "writing manifest"})

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

	fp, err := mp.GetManifestPath()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
		return err
	}

	err = os.WriteFile(fp, manifestJSON, 0o644)
	if err != nil {
		slog.Info(fmt.Sprintf("couldn't write to %s", fp))
		return err
	}

	// the old "noprune" local was never set, so pruning ran even with NoPrune
	if !envconfig.NoPrune {
		fn(api.ProgressResponse{Status: "removing any unused layers"})
		err = deleteUnusedLayers(nil, deleteMap, false)
		if err != nil {
			return err
		}
	}

	fn(api.ProgressResponse{Status: "success"})

	return nil
}
2024-02-14 19:29:49 +00:00
// pullModelManifest fetches the manifest for mp from its registry
// (GET /v2/<namespace/repository>/manifests/<tag>) and decodes it.
//
// A body that decodes to no manifest (JSON null), or to one without a config,
// is rejected. Callers can therefore dereference the result's Config.
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

	headers := make(http.Header)
	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var m *ManifestV2
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}

	if m == nil || m.Config == nil {
		return nil, errors.New("invalid manifest: missing config")
	}

	return m, nil
}
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
2023-09-28 17:00:34 +00:00
func GetSHA256Digest(r io.Reader) (string, int64) {
2023-07-19 00:14:12 +00:00
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
log.Fatal(err)
}
2023-09-28 17:00:34 +00:00
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
// errUnauthorized is returned by makeRequestWithRetry when the registry still
// rejects a request after re-authenticating with a non-anonymous token.
var errUnauthorized = errors.New("unauthorized: access denied")
// getTokenSubject returns the "sub" claim of a JWT, formatted with %s.
//
// Only the payload segment is base64url-decoded and JSON-parsed. The
// signature is NOT verified, so the result must not be trusted for
// authorization decisions. Malformed tokens and tokens without a "sub" claim
// are logged and yield "".
func getTokenSubject(token string) string {
	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		slog.Error("jwt token does not contain 3 parts")
		return ""
	}

	raw, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
		return ""
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(raw, &claims); err != nil {
		slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
		return ""
	}

	if sub, found := claims["sub"]; found {
		return fmt.Sprintf("%s", sub)
	}

	slog.Error("jwt does not contain 'sub' field")
	return ""
}
2024-02-14 19:29:49 +00:00
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ {
2024-02-14 19:29:49 +00:00
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
2023-08-17 19:35:29 +00:00
if err != nil {
if !errors.Is(err, context.Canceled) {
slog.Info(fmt.Sprintf("request failed: %v", err))
}
2023-08-17 19:35:29 +00:00
return nil, err
}
switch {
case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry
2024-02-14 19:29:49 +00:00
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
2023-08-17 19:35:29 +00:00
if err != nil {
return nil, err
}
anonymous = getTokenSubject(token) == "anonymous"
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
}
case resp.StatusCode == http.StatusNotFound:
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:
return resp, nil
2023-08-17 19:35:29 +00:00
}
}
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized
2023-08-17 19:35:29 +00:00
}
2024-02-14 19:29:49 +00:00
// makeRequest sends one HTTP request to requestURL with http.DefaultClient.
//
// Side effects and request setup:
//   - with regOpts.Insecure set, requestURL's scheme is rewritten to "http"
//     (the caller's URL is mutated);
//   - a non-nil headers map becomes the request's header map, so the
//     Authorization and User-Agent values are written into it;
//   - regOpts.Token is sent as a bearer token, otherwise Username/Password as
//     basic auth when both are set;
//   - a Content-Length header, if present, is parsed into req.ContentLength.
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
		requestURL.Scheme = "http"
	}

	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
	if err != nil {
		return nil, err
	}

	if headers != nil {
		req.Header = headers
	}

	if regOpts != nil {
		switch {
		case regOpts.Token != "":
			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
		case regOpts.Username != "" && regOpts.Password != "":
			req.SetBasicAuth(regOpts.Username, regOpts.Password)
		}
	}

	userAgent := fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())
	req.Header.Set("User-Agent", userAgent)

	if raw := req.Header.Get("Content-Length"); raw != "" {
		length, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return nil, err
		}
		req.ContentLength = length
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}

	return resp, nil
}
2023-08-10 18:34:25 +00:00
// getValue extracts the quoted value of key from a WWW-Authenticate style
// parameter list such as `realm="https://example.com/token",service="reg"`.
//
// The character right after "key=" (the opening quote) is skipped. The value
// runs to the next double quote that is followed by a comma or ends the
// header; quotes followed by anything else are kept as part of the value.
// It returns "" when key is absent or nothing follows "key=". The header
// comes from a remote server, so it must never cause an out-of-range slice.
func getValue(header, key string) string {
	startIdx := strings.Index(header, key+"=")
	if startIdx == -1 {
		return ""
	}

	// Move the index past `key="` to the first character of the value.
	startIdx += len(key) + 2
	if startIdx > len(header) {
		// header ends right after "key="; slicing from here would panic
		return ""
	}

	endIdx := startIdx
	for endIdx < len(header) {
		if header[endIdx] == '"' {
			if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
				endIdx++
				continue
			}
			break
		}
		endIdx++
	}

	return header[startIdx:endIdx]
}
2024-02-14 19:29:49 +00:00
// parseRegistryChallenge parses a "Bearer realm=...,service=...,scope=..."
// WWW-Authenticate header into a registryChallenge. Missing parameters are
// left empty.
func parseRegistryChallenge(authStr string) registryChallenge {
	params := strings.TrimPrefix(authStr, "Bearer ")

	var challenge registryChallenge
	challenge.Realm = getValue(params, "realm")
	challenge.Service = getValue(params, "service")
	challenge.Scope = getValue(params, "scope")

	return challenge
}
// errDigestMismatch is wrapped into the error verifyBlob returns when a blob's
// content does not hash to its expected digest; PullModel deletes such blobs.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
2023-07-24 18:53:01 +00:00
2023-07-20 18:44:05 +00:00
// verifyBlob recomputes the sha256 of the blob stored for digest. It returns
// an error wrapping errDigestMismatch when the result differs from digest.
//
// The hash is computed here rather than via GetSHA256Digest, which calls
// log.Fatal on a read error. An I/O error while reading a blob is therefore
// returned to the caller instead of terminating the process.
func verifyBlob(digest string) error {
	fp, err := GetBlobsPath(digest)
	if err != nil {
		return err
	}

	f, err := os.Open(fp)
	if err != nil {
		return err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return fmt.Errorf("reading blob %s: %w", digest, err)
	}

	fileDigest := fmt.Sprintf("sha256:%x", h.Sum(nil))
	if digest != fileDigest {
		return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
	}

	return nil
}