1336 lines
31 KiB
Go
1336 lines
31 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"text/template"
|
|
"text/template/parse"
|
|
|
|
"golang.org/x/exp/slices"
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
"github.com/jmorganca/ollama/llm"
|
|
"github.com/jmorganca/ollama/parser"
|
|
"github.com/jmorganca/ollama/version"
|
|
)
|
|
|
|
type RegistryOptions struct {
|
|
Insecure bool
|
|
Username string
|
|
Password string
|
|
Token string
|
|
}
|
|
|
|
type Model struct {
|
|
Name string `json:"name"`
|
|
Config ConfigV2
|
|
ShortName string
|
|
ModelPath string
|
|
ParentModel string
|
|
AdapterPaths []string
|
|
ProjectorPaths []string
|
|
Template string
|
|
System string
|
|
License []string
|
|
Digest string
|
|
Size int64
|
|
Options map[string]interface{}
|
|
Messages []Message
|
|
}
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
}
|
|
|
|
type PromptVars struct {
|
|
System string
|
|
Prompt string
|
|
Response string
|
|
First bool
|
|
Images []llm.ImageData
|
|
}
|
|
|
|
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
|
func extractParts(tmplStr string) (pre string, post string, err error) {
|
|
tmpl, err := template.New("").Parse(tmplStr)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
var foundResponse bool
|
|
|
|
for _, node := range tmpl.Tree.Root.Nodes {
|
|
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
|
|
foundResponse = true
|
|
}
|
|
if !foundResponse {
|
|
pre += node.String()
|
|
} else {
|
|
post += node.String()
|
|
}
|
|
}
|
|
|
|
return pre, post, nil
|
|
}
|
|
|
|
func Prompt(promptTemplate string, p PromptVars) (string, error) {
|
|
var prompt strings.Builder
|
|
// Use the "missingkey=zero" option to handle missing variables without panicking
|
|
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
vars := map[string]any{
|
|
"System": p.System,
|
|
"Prompt": p.Prompt,
|
|
"Response": p.Response,
|
|
"First": p.First,
|
|
}
|
|
|
|
var sb strings.Builder
|
|
if err := tmpl.Execute(&sb, vars); err != nil {
|
|
return "", err
|
|
}
|
|
prompt.WriteString(sb.String())
|
|
|
|
if !strings.Contains(prompt.String(), p.Response) {
|
|
// if the response is not in the prompt template, append it to the end
|
|
prompt.WriteString(p.Response)
|
|
}
|
|
|
|
return prompt.String(), nil
|
|
}
|
|
|
|
// PreResponsePrompt returns the prompt before the response tag
|
|
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
|
|
pre, _, err := extractParts(m.Template)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return Prompt(pre, p)
|
|
}
|
|
|
|
// PostResponseTemplate returns the template after the response tag
|
|
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
|
if p.System == "" {
|
|
// use the default system prompt for this model if one is not specified
|
|
p.System = m.System
|
|
}
|
|
_, post, err := extractParts(m.Template)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if post == "" {
|
|
// if there is no post-response template, return the provided response
|
|
return p.Response, nil
|
|
}
|
|
|
|
return Prompt(post, p)
|
|
}
|
|
|
|
type ChatHistory struct {
|
|
Prompts []PromptVars
|
|
LastSystem string
|
|
}
|
|
|
|
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
|
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
|
// build the prompt from the list of messages
|
|
lastSystem := m.System
|
|
currentVars := PromptVars{
|
|
First: true,
|
|
System: m.System,
|
|
}
|
|
|
|
prompts := []PromptVars{}
|
|
var images []llm.ImageData
|
|
|
|
for _, msg := range msgs {
|
|
switch strings.ToLower(msg.Role) {
|
|
case "system":
|
|
// if this is the first message it overrides the system prompt in the modelfile
|
|
if !currentVars.First && currentVars.System != "" {
|
|
prompts = append(prompts, currentVars)
|
|
currentVars = PromptVars{}
|
|
}
|
|
currentVars.System = msg.Content
|
|
lastSystem = msg.Content
|
|
case "user":
|
|
if currentVars.Prompt != "" {
|
|
prompts = append(prompts, currentVars)
|
|
currentVars = PromptVars{}
|
|
}
|
|
|
|
currentVars.Prompt = msg.Content
|
|
|
|
if len(m.ProjectorPaths) > 0 {
|
|
for i := range msg.Images {
|
|
id := len(images) + i
|
|
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
|
|
currentVars.Images = append(currentVars.Images, llm.ImageData{
|
|
ID: id,
|
|
Data: msg.Images[i],
|
|
})
|
|
}
|
|
|
|
images = append(images, currentVars.Images...)
|
|
}
|
|
case "assistant":
|
|
currentVars.Response = msg.Content
|
|
prompts = append(prompts, currentVars)
|
|
currentVars = PromptVars{}
|
|
default:
|
|
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
|
}
|
|
}
|
|
|
|
// Append the last set of vars if they are non-empty
|
|
if currentVars.Prompt != "" || currentVars.System != "" {
|
|
prompts = append(prompts, currentVars)
|
|
}
|
|
|
|
return &ChatHistory{
|
|
Prompts: prompts,
|
|
LastSystem: lastSystem,
|
|
}, nil
|
|
}
|
|
|
|
type ManifestV2 struct {
|
|
SchemaVersion int `json:"schemaVersion"`
|
|
MediaType string `json:"mediaType"`
|
|
Config *Layer `json:"config"`
|
|
Layers []*Layer `json:"layers"`
|
|
}
|
|
|
|
type ConfigV2 struct {
|
|
ModelFormat string `json:"model_format"`
|
|
ModelFamily string `json:"model_family"`
|
|
ModelFamilies []string `json:"model_families"`
|
|
ModelType string `json:"model_type"`
|
|
FileType string `json:"file_type"`
|
|
|
|
// required by spec
|
|
Architecture string `json:"architecture"`
|
|
OS string `json:"os"`
|
|
RootFS RootFS `json:"rootfs"`
|
|
}
|
|
|
|
func (c *ConfigV2) SetModelFormat(format string) {
|
|
if c.ModelFormat == "" {
|
|
c.ModelFormat = format
|
|
}
|
|
}
|
|
|
|
func (c *ConfigV2) SetModelFamily(families ...string) {
|
|
for _, family := range families {
|
|
if c.ModelFamily == "" {
|
|
c.ModelFamily = family
|
|
}
|
|
|
|
if !slices.Contains(c.ModelFamilies, family) {
|
|
c.ModelFamilies = append(c.ModelFamilies, family)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ConfigV2) SetModelType(modelType string) {
|
|
if c.ModelType == "" {
|
|
c.ModelType = modelType
|
|
}
|
|
}
|
|
|
|
func (c *ConfigV2) SetFileType(fileType string) {
|
|
if c.FileType == "" {
|
|
c.FileType = fileType
|
|
}
|
|
}
|
|
|
|
type RootFS struct {
|
|
Type string `json:"type"`
|
|
DiffIDs []string `json:"diff_ids"`
|
|
}
|
|
|
|
func (m *ManifestV2) GetTotalSize() (total int64) {
|
|
for _, layer := range m.Layers {
|
|
total += layer.Size
|
|
}
|
|
|
|
total += m.Config.Size
|
|
return total
|
|
}
|
|
|
|
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
|
fp, err := mp.GetManifestPath()
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
if _, err = os.Stat(fp); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
var manifest *ManifestV2
|
|
|
|
bts, err := os.ReadFile(fp)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
|
|
}
|
|
|
|
shaSum := sha256.Sum256(bts)
|
|
shaStr := hex.EncodeToString(shaSum[:])
|
|
|
|
if err := json.Unmarshal(bts, &manifest); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return manifest, shaStr, nil
|
|
}
|
|
|
|
func GetModel(name string) (*Model, error) {
|
|
mp := ParseModelPath(name)
|
|
manifest, digest, err := GetManifest(mp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
model := &Model{
|
|
Name: mp.GetFullTagname(),
|
|
ShortName: mp.GetShortTagname(),
|
|
Digest: digest,
|
|
Template: "{{ .Prompt }}",
|
|
License: []string{},
|
|
Size: manifest.GetTotalSize(),
|
|
}
|
|
|
|
filename, err := GetBlobsPath(manifest.Config.Digest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
configFile, err := os.Open(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer configFile.Close()
|
|
|
|
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, layer := range manifest.Layers {
|
|
filename, err := GetBlobsPath(layer.Digest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch layer.MediaType {
|
|
case "application/vnd.ollama.image.model":
|
|
model.ModelPath = filename
|
|
model.ParentModel = layer.From
|
|
case "application/vnd.ollama.image.embed":
|
|
// Deprecated in versions > 0.1.2
|
|
// TODO: remove this warning in a future version
|
|
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
|
case "application/vnd.ollama.image.adapter":
|
|
model.AdapterPaths = append(model.AdapterPaths, filename)
|
|
case "application/vnd.ollama.image.projector":
|
|
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
|
case "application/vnd.ollama.image.template":
|
|
bts, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
model.Template = string(bts)
|
|
case "application/vnd.ollama.image.system":
|
|
bts, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
model.System = string(bts)
|
|
case "application/vnd.ollama.image.prompt":
|
|
bts, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
model.Template = string(bts)
|
|
case "application/vnd.ollama.image.params":
|
|
params, err := os.Open(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer params.Close()
|
|
|
|
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
|
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
|
return nil, err
|
|
}
|
|
case "application/vnd.ollama.image.messages":
|
|
msgs, err := os.Open(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer msgs.Close()
|
|
|
|
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
|
return nil, err
|
|
}
|
|
case "application/vnd.ollama.image.license":
|
|
bts, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
model.License = append(model.License, string(bts))
|
|
}
|
|
}
|
|
|
|
return model, nil
|
|
}
|
|
|
|
func realpath(mfDir, from string) string {
|
|
abspath, err := filepath.Abs(from)
|
|
if err != nil {
|
|
return from
|
|
}
|
|
|
|
home, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return abspath
|
|
}
|
|
|
|
if from == "~" {
|
|
return home
|
|
} else if strings.HasPrefix(from, "~/") {
|
|
return filepath.Join(home, from[2:])
|
|
}
|
|
|
|
if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil {
|
|
// this is a file relative to the Modelfile
|
|
return filepath.Join(mfDir, from)
|
|
}
|
|
|
|
return abspath
|
|
}
|
|
|
|
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
|
|
deleteMap := make(map[string]struct{})
|
|
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
|
|
for _, layer := range append(manifest.Layers, manifest.Config) {
|
|
deleteMap[layer.Digest] = struct{}{}
|
|
}
|
|
}
|
|
|
|
config := ConfigV2{
|
|
OS: "linux",
|
|
Architecture: "amd64",
|
|
RootFS: RootFS{
|
|
Type: "layers",
|
|
},
|
|
}
|
|
|
|
var layers Layers
|
|
messages := []string{}
|
|
|
|
params := make(map[string][]string)
|
|
fromParams := make(map[string]any)
|
|
|
|
for _, c := range commands {
|
|
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
|
|
|
switch c.Name {
|
|
case "model":
|
|
if strings.HasPrefix(c.Args, "@") {
|
|
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.Args = blobPath
|
|
}
|
|
|
|
bin, err := os.Open(realpath(modelFileDir, c.Args))
|
|
if err != nil {
|
|
// not a file on disk so must be a model reference
|
|
modelpath := ParseModelPath(c.Args)
|
|
manifest, _, err := GetManifest(modelpath)
|
|
switch {
|
|
case errors.Is(err, os.ErrNotExist):
|
|
fn(api.ProgressResponse{Status: "pulling model"})
|
|
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
|
|
return err
|
|
}
|
|
|
|
manifest, _, err = GetManifest(modelpath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case err != nil:
|
|
return err
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "reading model metadata"})
|
|
fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fromConfigFile, err := os.Open(fromConfigPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer fromConfigFile.Close()
|
|
|
|
var fromConfig ConfigV2
|
|
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
|
|
return err
|
|
}
|
|
|
|
// if the model is still not in gguf format, error out
|
|
if fromConfig.ModelFormat != "gguf" {
|
|
return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
|
|
}
|
|
|
|
config.SetModelFormat(fromConfig.ModelFormat)
|
|
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
|
|
config.SetModelType(fromConfig.ModelType)
|
|
config.SetFileType(fromConfig.FileType)
|
|
|
|
for _, layer := range manifest.Layers {
|
|
deleteMap[layer.Digest] = struct{}{}
|
|
if layer.MediaType == "application/vnd.ollama.image.params" {
|
|
fromParamsPath, err := GetBlobsPath(layer.Digest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fromParamsFile, err := os.Open(fromParamsPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer fromParamsFile.Close()
|
|
|
|
if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Add(layer)
|
|
}
|
|
|
|
deleteMap[manifest.Config.Digest] = struct{}{}
|
|
continue
|
|
}
|
|
defer bin.Close()
|
|
|
|
var offset int64
|
|
CREATE:
|
|
for {
|
|
fn(api.ProgressResponse{Status: "creating model layer"})
|
|
|
|
bin.Seek(offset, io.SeekStart)
|
|
ggml, err := llm.DecodeGGML(bin)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, io.EOF):
|
|
break CREATE
|
|
case errors.Is(err, llm.ErrUnsupportedFormat):
|
|
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
|
|
config.SetModelFormat(ggml.Name())
|
|
config.SetModelFamily(ggml.ModelFamily())
|
|
config.SetModelType(ggml.ModelType())
|
|
config.SetFileType(ggml.FileType())
|
|
|
|
mediatype := mediatype
|
|
if ggml.ModelFamily() == "clip" {
|
|
mediatype = "application/vnd.ollama.image.projector"
|
|
}
|
|
|
|
sr := io.NewSectionReader(bin, offset, ggml.Size)
|
|
layer, err := NewLayer(sr, mediatype)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Add(layer)
|
|
|
|
offset += ggml.Size
|
|
}
|
|
case "adapter":
|
|
if strings.HasPrefix(c.Args, "@") {
|
|
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.Args = blobPath
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "creating adapter layer"})
|
|
bin, err := os.Open(realpath(modelFileDir, c.Args))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer bin.Close()
|
|
|
|
layer, err := NewLayer(bin, mediatype)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Add(layer)
|
|
case "license":
|
|
fn(api.ProgressResponse{Status: "creating license layer"})
|
|
|
|
bin := strings.NewReader(c.Args)
|
|
layer, err := NewLayer(bin, mediatype)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Add(layer)
|
|
case "template", "system":
|
|
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
|
|
|
|
bin := strings.NewReader(c.Args)
|
|
layer, err := NewLayer(bin, mediatype)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Replace(layer)
|
|
case "message":
|
|
messages = append(messages, c.Args)
|
|
default:
|
|
params[c.Name] = append(params[c.Name], c.Args)
|
|
}
|
|
}
|
|
|
|
if len(messages) > 0 {
|
|
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
|
|
|
msgs := make([]api.Message, 0)
|
|
|
|
for _, m := range messages {
|
|
// todo: handle images
|
|
msg := strings.SplitN(m, ": ", 2)
|
|
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
|
|
return err
|
|
}
|
|
|
|
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Replace(layer)
|
|
}
|
|
|
|
if len(params) > 0 {
|
|
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
|
|
|
formattedParams, err := api.FormatParams(params)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for k, v := range fromParams {
|
|
if _, ok := formattedParams[k]; !ok {
|
|
formattedParams[k] = v
|
|
}
|
|
}
|
|
|
|
// xxx - can this be removed?
|
|
if config.ModelType == "65B" {
|
|
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
|
|
config.ModelType = "70B"
|
|
}
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
|
|
return err
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "creating config layer"})
|
|
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
layers.Replace(layer)
|
|
}
|
|
|
|
digests := make([]string, len(layers.items))
|
|
for i, layer := range layers.items {
|
|
digests[i] = layer.Digest
|
|
}
|
|
|
|
config.RootFS.DiffIDs = digests
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
|
return err
|
|
}
|
|
|
|
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
delete(deleteMap, configLayer.Digest)
|
|
|
|
for _, layer := range append(layers.items, configLayer) {
|
|
committed, err := layer.Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
status := "writing layer"
|
|
if !committed {
|
|
status = "using already created layer"
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
|
|
|
|
delete(deleteMap, layer.Digest)
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
|
if err := WriteManifest(name, configLayer, layers.items); err != nil {
|
|
return err
|
|
}
|
|
|
|
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
|
if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "success"})
|
|
return nil
|
|
}
|
|
|
|
func CopyModel(src, dest string) error {
|
|
srcModelPath := ParseModelPath(src)
|
|
srcPath, err := srcModelPath.GetManifestPath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
destModelPath := ParseModelPath(dest)
|
|
destPath, err := destModelPath.GetManifestPath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
|
|
return err
|
|
}
|
|
|
|
// copy the file
|
|
input, err := os.ReadFile(srcPath)
|
|
if err != nil {
|
|
fmt.Println("Error reading file:", err)
|
|
return err
|
|
}
|
|
|
|
err = os.WriteFile(destPath, input, 0o644)
|
|
if err != nil {
|
|
fmt.Println("Error reading file:", err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
|
|
fp, err := GetManifestPath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
walkFunc := func(path string, info os.FileInfo, _ error) error {
|
|
if info.IsDir() {
|
|
return nil
|
|
}
|
|
|
|
dir, file := filepath.Split(path)
|
|
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
|
|
tag := strings.Join([]string{dir, file}, ":")
|
|
fmp := ParseModelPath(tag)
|
|
|
|
// skip the manifest we're trying to delete
|
|
if skipModelPath != nil && skipModelPath.GetFullTagname() == fmp.GetFullTagname() {
|
|
return nil
|
|
}
|
|
|
|
// save (i.e. delete from the deleteMap) any files used in other manifests
|
|
manifest, _, err := GetManifest(fmp)
|
|
if err != nil {
|
|
// nolint: nilerr
|
|
return nil
|
|
}
|
|
|
|
for _, layer := range manifest.Layers {
|
|
delete(deleteMap, layer.Digest)
|
|
}
|
|
|
|
delete(deleteMap, manifest.Config.Digest)
|
|
return nil
|
|
}
|
|
|
|
if err := filepath.Walk(fp, walkFunc); err != nil {
|
|
return err
|
|
}
|
|
|
|
// only delete the files which are still in the deleteMap
|
|
for k := range deleteMap {
|
|
fp, err := GetBlobsPath(k)
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
|
continue
|
|
}
|
|
if !dryRun {
|
|
if err := os.Remove(fp); err != nil {
|
|
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
|
|
continue
|
|
}
|
|
} else {
|
|
slog.Info(fmt.Sprintf("wanted to remove: %s", fp))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func PruneLayers() error {
|
|
deleteMap := make(map[string]struct{})
|
|
p, err := GetBlobsPath("")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
blobs, err := os.ReadDir(p)
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
|
|
return err
|
|
}
|
|
|
|
for _, blob := range blobs {
|
|
name := blob.Name()
|
|
if runtime.GOOS == "windows" {
|
|
name = strings.ReplaceAll(name, "-", ":")
|
|
}
|
|
if strings.HasPrefix(name, "sha256:") {
|
|
deleteMap[name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
|
|
|
|
err = deleteUnusedLayers(nil, deleteMap, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
|
|
|
|
return nil
|
|
}
|
|
|
|
func PruneDirectory(path string) error {
|
|
info, err := os.Lstat(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
|
|
entries, err := os.ReadDir(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
entries, err = os.ReadDir(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(entries) > 0 {
|
|
return nil
|
|
}
|
|
|
|
return os.Remove(path)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func DeleteModel(name string) error {
|
|
mp := ParseModelPath(name)
|
|
manifest, _, err := GetManifest(mp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
deleteMap := make(map[string]struct{})
|
|
for _, layer := range manifest.Layers {
|
|
deleteMap[layer.Digest] = struct{}{}
|
|
}
|
|
deleteMap[manifest.Config.Digest] = struct{}{}
|
|
|
|
err = deleteUnusedLayers(&mp, deleteMap, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fp, err := mp.GetManifestPath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = os.Remove(fp)
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ShowModelfile(model *Model) (string, error) {
|
|
var mt struct {
|
|
*Model
|
|
From string
|
|
Parameters map[string][]any
|
|
}
|
|
|
|
mt.Parameters = make(map[string][]any)
|
|
for k, v := range model.Options {
|
|
if s, ok := v.([]any); ok {
|
|
mt.Parameters[k] = s
|
|
continue
|
|
}
|
|
|
|
mt.Parameters[k] = []any{v}
|
|
}
|
|
|
|
mt.Model = model
|
|
mt.From = model.ModelPath
|
|
|
|
if model.ParentModel != "" {
|
|
mt.From = model.ParentModel
|
|
}
|
|
|
|
modelFile := `# Modelfile generated by "ollama show"
|
|
# To build a new Modelfile based on this one, replace the FROM line with:
|
|
# FROM {{ .ShortName }}
|
|
|
|
FROM {{ .From }}
|
|
TEMPLATE """{{ .Template }}"""
|
|
|
|
{{- if .System }}
|
|
SYSTEM """{{ .System }}"""
|
|
{{- end }}
|
|
|
|
{{- range $adapter := .AdapterPaths }}
|
|
ADAPTER {{ $adapter }}
|
|
{{- end }}
|
|
|
|
{{- range $k, $v := .Parameters }}
|
|
{{- range $parameter := $v }}
|
|
PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
|
{{- end }}
|
|
{{- end }}`
|
|
|
|
tmpl, err := template.New("").Parse(modelFile)
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("error parsing template: %q", err))
|
|
return "", err
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
|
|
if err = tmpl.Execute(&buf, mt); err != nil {
|
|
slog.Info(fmt.Sprintf("error executing template: %q", err))
|
|
return "", err
|
|
}
|
|
|
|
return buf.String(), nil
|
|
}
|
|
|
|
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
mp := ParseModelPath(name)
|
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
|
|
|
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
|
return fmt.Errorf("insecure protocol http")
|
|
}
|
|
|
|
manifest, _, err := GetManifest(mp)
|
|
if err != nil {
|
|
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
|
return err
|
|
}
|
|
|
|
var layers []*Layer
|
|
layers = append(layers, manifest.Layers...)
|
|
layers = append(layers, manifest.Config)
|
|
|
|
for _, layer := range layers {
|
|
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
|
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
|
if errors.Is(err, errUnauthorized) {
|
|
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "pushing manifest"})
|
|
requestURL := mp.BaseURL()
|
|
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
|
|
|
manifestJSON, err := json.Marshal(manifest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
headers := make(http.Header)
|
|
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
|
|
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
fn(api.ProgressResponse{Status: "success"})
|
|
|
|
return nil
|
|
}
|
|
|
|
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
mp := ParseModelPath(name)
|
|
|
|
var manifest *ManifestV2
|
|
var err error
|
|
var noprune string
|
|
|
|
// build deleteMap to prune unused layers
|
|
deleteMap := make(map[string]struct{})
|
|
|
|
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
|
manifest, _, err = GetManifest(mp)
|
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return err
|
|
}
|
|
|
|
if manifest != nil {
|
|
for _, l := range manifest.Layers {
|
|
deleteMap[l.Digest] = struct{}{}
|
|
}
|
|
deleteMap[manifest.Config.Digest] = struct{}{}
|
|
}
|
|
}
|
|
|
|
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
|
return fmt.Errorf("insecure protocol http")
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "pulling manifest"})
|
|
|
|
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
|
if err != nil {
|
|
return fmt.Errorf("pull model manifest: %s", err)
|
|
}
|
|
|
|
var layers []*Layer
|
|
layers = append(layers, manifest.Layers...)
|
|
layers = append(layers, manifest.Config)
|
|
|
|
for _, layer := range layers {
|
|
if err := downloadBlob(
|
|
ctx,
|
|
downloadOpts{
|
|
mp: mp,
|
|
digest: layer.Digest,
|
|
regOpts: regOpts,
|
|
fn: fn,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
delete(deleteMap, layer.Digest)
|
|
}
|
|
delete(deleteMap, manifest.Config.Digest)
|
|
|
|
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
|
for _, layer := range layers {
|
|
if err := verifyBlob(layer.Digest); err != nil {
|
|
if errors.Is(err, errDigestMismatch) {
|
|
// something went wrong, delete the blob
|
|
fp, err := GetBlobsPath(layer.Digest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := os.Remove(fp); err != nil {
|
|
// log this, but return the original error
|
|
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
|
|
|
manifestJSON, err := json.Marshal(manifest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fp, err := mp.GetManifestPath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
|
return err
|
|
}
|
|
|
|
err = os.WriteFile(fp, manifestJSON, 0o644)
|
|
if err != nil {
|
|
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
|
|
return err
|
|
}
|
|
|
|
if noprune == "" {
|
|
fn(api.ProgressResponse{Status: "removing any unused layers"})
|
|
err = deleteUnusedLayers(nil, deleteMap, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
fn(api.ProgressResponse{Status: "success"})
|
|
|
|
return nil
|
|
}
|
|
|
|
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
|
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
|
|
|
headers := make(http.Header)
|
|
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
|
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var m *ManifestV2
|
|
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return m, err
|
|
}
|
|
|
|
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
|
|
func GetSHA256Digest(r io.Reader) (string, int64) {
|
|
h := sha256.New()
|
|
n, err := io.Copy(h, r)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
|
}
|
|
|
|
var errUnauthorized = fmt.Errorf("unauthorized")
|
|
|
|
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
|
for i := 0; i < 2; i++ {
|
|
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
|
if err != nil {
|
|
if !errors.Is(err, context.Canceled) {
|
|
slog.Info(fmt.Sprintf("request failed: %v", err))
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
switch {
|
|
case resp.StatusCode == http.StatusUnauthorized:
|
|
// Handle authentication error with one retry
|
|
auth := resp.Header.Get("www-authenticate")
|
|
authRedir := ParseAuthRedirectString(auth)
|
|
token, err := getAuthToken(ctx, authRedir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
regOpts.Token = token
|
|
if body != nil {
|
|
_, err = body.Seek(0, io.SeekStart)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
case resp.StatusCode == http.StatusNotFound:
|
|
return nil, os.ErrNotExist
|
|
case resp.StatusCode >= http.StatusBadRequest:
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
|
}
|
|
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
|
default:
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
return nil, errUnauthorized
|
|
}
|
|
|
|
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
|
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
|
requestURL.Scheme = "http"
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if headers != nil {
|
|
req.Header = headers
|
|
}
|
|
|
|
if regOpts != nil {
|
|
if regOpts.Token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
|
} else if regOpts.Username != "" && regOpts.Password != "" {
|
|
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
|
}
|
|
}
|
|
|
|
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
|
|
|
if s := req.Header.Get("Content-Length"); s != "" {
|
|
contentLength, err := strconv.ParseInt(s, 10, 64)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.ContentLength = contentLength
|
|
}
|
|
|
|
proxyURL, err := http.ProxyFromEnvironment(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := http.Client{
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyURL(proxyURL),
|
|
},
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func getValue(header, key string) string {
|
|
startIdx := strings.Index(header, key+"=")
|
|
if startIdx == -1 {
|
|
return ""
|
|
}
|
|
|
|
// Move the index to the starting quote after the key.
|
|
startIdx += len(key) + 2
|
|
endIdx := startIdx
|
|
|
|
for endIdx < len(header) {
|
|
if header[endIdx] == '"' {
|
|
if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
|
|
endIdx++
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
endIdx++
|
|
}
|
|
return header[startIdx:endIdx]
|
|
}
|
|
|
|
func ParseAuthRedirectString(authStr string) AuthRedirect {
|
|
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
|
|
|
return AuthRedirect{
|
|
Realm: getValue(authStr, "realm"),
|
|
Service: getValue(authStr, "service"),
|
|
Scope: getValue(authStr, "scope"),
|
|
}
|
|
}
|
|
|
|
var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
|
|
|
|
func verifyBlob(digest string) error {
|
|
fp, err := GetBlobsPath(digest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
f, err := os.Open(fp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
fileDigest, _ := GetSHA256Digest(f)
|
|
if digest != fileDigest {
|
|
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
|
|
}
|
|
|
|
return nil
|
|
}
|