(false)
- const command = 'ollama run llama2'
+ const command = 'ollama run llama3'
return (
diff --git a/parser/parser.go b/parser/parser.go
deleted file mode 100644
index 947848b2..00000000
--- a/parser/parser.go
+++ /dev/null
@@ -1,132 +0,0 @@
-package parser
-
-import (
- "bufio"
- "bytes"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "slices"
-)
-
-type Command struct {
- Name string
- Args string
-}
-
-func (c *Command) Reset() {
- c.Name = ""
- c.Args = ""
-}
-
-func Parse(reader io.Reader) ([]Command, error) {
- var commands []Command
- var command, modelCommand Command
-
- scanner := bufio.NewScanner(reader)
- scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
- scanner.Split(scanModelfile)
- for scanner.Scan() {
- line := scanner.Bytes()
-
- fields := bytes.SplitN(line, []byte(" "), 2)
- if len(fields) == 0 || len(fields[0]) == 0 {
- continue
- }
-
- switch string(bytes.ToUpper(fields[0])) {
- case "FROM":
- command.Name = "model"
- command.Args = string(bytes.TrimSpace(fields[1]))
- // copy command for validation
- modelCommand = command
- case "ADAPTER":
- command.Name = string(bytes.ToLower(fields[0]))
- command.Args = string(bytes.TrimSpace(fields[1]))
- case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
- command.Name = string(bytes.ToLower(fields[0]))
- command.Args = string(fields[1])
- case "PARAMETER":
- fields = bytes.SplitN(fields[1], []byte(" "), 2)
- if len(fields) < 2 {
- return nil, fmt.Errorf("missing value for %s", fields)
- }
-
- command.Name = string(fields[0])
- command.Args = string(bytes.TrimSpace(fields[1]))
- case "EMBED":
- return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
- case "MESSAGE":
- command.Name = string(bytes.ToLower(fields[0]))
- fields = bytes.SplitN(fields[1], []byte(" "), 2)
- if len(fields) < 2 {
- return nil, fmt.Errorf("should be in the format <role> <message>")
- }
- if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
- return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
- }
- command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
- default:
- if !bytes.HasPrefix(fields[0], []byte("#")) {
- // log a warning for unknown commands
- slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
- }
- continue
- }
-
- commands = append(commands, command)
- command.Reset()
- }
-
- if modelCommand.Args == "" {
- return nil, errors.New("no FROM line for the model was specified")
- }
-
- return commands, scanner.Err()
-}
-
-func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
- advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
- if err != nil {
- return 0, nil, err
- }
-
- if advance > 0 && token != nil {
- return advance, token, nil
- }
-
- advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
- if err != nil {
- return 0, nil, err
- }
-
- if advance > 0 && token != nil {
- return advance, token, nil
- }
-
- return bufio.ScanLines(data, atEOF)
-}
-
-func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
- newline := bytes.IndexByte(data, '\n')
-
- if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
- end := bytes.Index(data[start+len(openBytes):], closeBytes)
- if end < 0 {
- if atEOF {
- return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
- } else {
- return 0, nil, nil
- }
- }
-
- n := start + len(openBytes) + end + len(closeBytes)
-
- newData := data[:start]
- newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
- return n, newData, nil
- }
-
- return 0, nil, nil
-}
diff --git a/parser/parser_test.go b/parser/parser_test.go
deleted file mode 100644
index 25e849b5..00000000
--- a/parser/parser_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package parser
-
-import (
- "strings"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func Test_Parser(t *testing.T) {
-
- input := `
-FROM model1
-ADAPTER adapter1
-LICENSE MIT
-PARAMETER param1 value1
-PARAMETER param2 value2
-TEMPLATE template1
-`
-
- reader := strings.NewReader(input)
-
- commands, err := Parse(reader)
- assert.Nil(t, err)
-
- expectedCommands := []Command{
- {Name: "model", Args: "model1"},
- {Name: "adapter", Args: "adapter1"},
- {Name: "license", Args: "MIT"},
- {Name: "param1", Args: "value1"},
- {Name: "param2", Args: "value2"},
- {Name: "template", Args: "template1"},
- }
-
- assert.Equal(t, expectedCommands, commands)
-}
-
-func Test_Parser_NoFromLine(t *testing.T) {
-
- input := `
-PARAMETER param1 value1
-PARAMETER param2 value2
-`
-
- reader := strings.NewReader(input)
-
- _, err := Parse(reader)
- assert.ErrorContains(t, err, "no FROM line")
-}
-
-func Test_Parser_MissingValue(t *testing.T) {
-
- input := `
-FROM foo
-PARAMETER param1
-`
-
- reader := strings.NewReader(input)
-
- _, err := Parse(reader)
- assert.ErrorContains(t, err, "missing value for [param1]")
-
-}
-
-func Test_Parser_Messages(t *testing.T) {
-
- input := `
-FROM foo
-MESSAGE system You are a Parser. Always Parse things.
-MESSAGE user Hey there!
-MESSAGE assistant Hello, I want to parse all the things!
-`
-
- reader := strings.NewReader(input)
- commands, err := Parse(reader)
- assert.Nil(t, err)
-
- expectedCommands := []Command{
- {Name: "model", Args: "foo"},
- {Name: "message", Args: "system: You are a Parser. Always Parse things."},
- {Name: "message", Args: "user: Hey there!"},
- {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
- }
-
- assert.Equal(t, expectedCommands, commands)
-}
-
-func Test_Parser_Messages_BadRole(t *testing.T) {
-
- input := `
-FROM foo
-MESSAGE badguy I'm a bad guy!
-`
-
- reader := strings.NewReader(input)
- _, err := Parse(reader)
- assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
-}
diff --git a/readline/readline.go b/readline/readline.go
index 8ba7d89c..6fa45391 100644
--- a/readline/readline.go
+++ b/readline/readline.go
@@ -218,7 +218,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := int(syscall.Stdin)
return handleCharCtrlZ(fd, i.Terminal.termios)
- case CharEnter:
+ case CharEnter, CharCtrlJ:
output := buf.String()
if output != "" {
i.History.Add([]rune(output))
@@ -232,7 +232,7 @@ func (i *Instance) Readline() (string, error) {
metaDel = false
continue
}
- if r >= CharSpace || r == CharEnter {
+ if r >= CharSpace || r == CharEnter || r == CharCtrlJ {
buf.Add(r)
}
}
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 1a89045a..60de0307 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -7,6 +7,8 @@
$ErrorActionPreference = "Stop"
function checkEnv() {
+ $script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
+ Write-host "Building for ${script:TARGET_ARCH}"
write-host "Locating required tools and paths"
$script:SRC_DIR=$PWD
if (!$env:VCToolsRedistDir) {
@@ -30,7 +32,7 @@ function checkEnv() {
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
- $script:DEPS_DIR="${script:SRC_DIR}\dist\windeps"
+ $script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$env:CGO_ENABLED="1"
echo "Checking version"
if (!$env:VERSION) {
@@ -81,8 +83,8 @@ function buildOllama() {
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
- New-Item -ItemType Directory -Path .\dist -Force
- cp .\ollama.exe .\dist\ollama-windows-amd64.exe
+ New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
+ cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
}
function buildApp() {
@@ -101,7 +103,6 @@ function buildApp() {
function gatherDependencies() {
write-host "Gathering runtime dependencies"
cd "${script:SRC_DIR}"
- rm -ea 0 -recurse -force -path "${script:DEPS_DIR}"
md "${script:DEPS_DIR}" -ea 0 > $null
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
@@ -110,9 +111,6 @@ function gatherDependencies() {
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\"
- cp "${script:NVIDIA_DIR}\cudart64_*.dll" "${script:DEPS_DIR}\"
- cp "${script:NVIDIA_DIR}\cublas64_*.dll" "${script:DEPS_DIR}\"
- cp "${script:NVIDIA_DIR}\cublasLt64_*.dll" "${script:DEPS_DIR}\"
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
if ("${env:KEY_CONTAINER}") {
@@ -124,7 +122,6 @@ function gatherDependencies() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
-
}
function buildInstaller() {
@@ -132,19 +129,25 @@ function buildInstaller() {
cd "${script:SRC_DIR}\app"
$env:PKG_VERSION=$script:PKG_VERSION
if ("${env:KEY_CONTAINER}") {
- & "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
+ & "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
} else {
- & "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
+ & "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
}
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
+function distZip() {
+    # Bundle everything staged under dist\windows-<arch> into one zip in dist\.
+    $stagedFiles = "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*"
+    $zipFile = "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
+    write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
+    # -Force overwrites a zip left behind by an earlier build.
+    Compress-Archive -Path $stagedFiles -DestinationPath $zipFile -Force
+}
+
try {
checkEnv
buildOllama
buildApp
gatherDependencies
buildInstaller
+ distZip
} catch {
write-host "Build Failed"
write-host $_
diff --git a/scripts/install.sh b/scripts/install.sh
index eb3ff504..20b0db60 100644
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -166,8 +166,8 @@ fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
# Look for pre-existing ROCm v6 before downloading the dependencies
- for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm"; do
- if [ -n "${search}" ] && [ -e "${search}/lib/libhipblas.so.2" ]; then
+ for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
+ if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
status "Compatible AMD GPU ROCm library detected at ${search}"
install_success
exit 0
diff --git a/server/envconfig/config.go b/server/envconfig/config.go
new file mode 100644
index 00000000..9ad68180
--- /dev/null
+++ b/server/envconfig/config.go
@@ -0,0 +1,174 @@
+package envconfig
+
+import (
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+)
+
+var (
+ // Set via OLLAMA_ORIGINS in the environment
+ AllowOrigins []string
+ // Set via OLLAMA_DEBUG in the environment
+ Debug bool
+ // Set via OLLAMA_LLM_LIBRARY in the environment
+ LLMLibrary string
+ // Set via OLLAMA_MAX_LOADED_MODELS in the environment
+ MaxRunners int
+ // Set via OLLAMA_MAX_QUEUE in the environment
+ MaxQueuedRequests int
+ // Set via OLLAMA_MAX_VRAM in the environment
+ MaxVRAM uint64
+ // Set via OLLAMA_NOPRUNE in the environment
+ NoPrune bool
+ // Set via OLLAMA_NUM_PARALLEL in the environment
+ NumParallel int
+ // Set via OLLAMA_RUNNERS_DIR in the environment
+ RunnersDir string
+ // Set via OLLAMA_TMPDIR in the environment
+ TmpDir string
+)
+
+// AsMap reports the current value of every setting in this package, keyed by
+// the name of the environment variable it is set from. Each value is rendered
+// with fmt's default formatting.
+func AsMap() map[string]string {
+	settings := make(map[string]string, 10)
+
+	settings["OLLAMA_ORIGINS"] = fmt.Sprint(AllowOrigins)
+	settings["OLLAMA_DEBUG"] = fmt.Sprint(Debug)
+	settings["OLLAMA_LLM_LIBRARY"] = fmt.Sprint(LLMLibrary)
+	settings["OLLAMA_MAX_LOADED_MODELS"] = fmt.Sprint(MaxRunners)
+	settings["OLLAMA_MAX_QUEUE"] = fmt.Sprint(MaxQueuedRequests)
+	settings["OLLAMA_MAX_VRAM"] = fmt.Sprint(MaxVRAM)
+	settings["OLLAMA_NOPRUNE"] = fmt.Sprint(NoPrune)
+	settings["OLLAMA_NUM_PARALLEL"] = fmt.Sprint(NumParallel)
+	settings["OLLAMA_RUNNERS_DIR"] = fmt.Sprint(RunnersDir)
+	settings["OLLAMA_TMPDIR"] = fmt.Sprint(TmpDir)
+
+	return settings
+}
+
+var defaultAllowOrigins = []string{
+ "localhost",
+ "127.0.0.1",
+ "0.0.0.0",
+}
+
+// clean looks up the environment variable key and returns its value with any
+// leading and trailing double quotes, single quotes, and spaces removed.
+func clean(key string) string {
+	const cutset = "\"' "
+
+	raw := os.Getenv(key)
+	return strings.Trim(raw, cutset)
+}
+
+// init seeds the settings that have non-zero defaults and then applies any
+// overrides found in the environment.
+func init() {
+	// These remain in effect when the matching variable is unset or invalid.
+	MaxQueuedRequests = 512
+	MaxRunners = 1
+	NumParallel = 1
+
+	LoadConfig()
+}
+
+// LoadConfig reads the OLLAMA_* environment variables and stores the results
+// in the package-level settings. Every value is read through clean, so
+// surrounding quotes and spaces are ignored.
+//
+// Numeric settings that fail to parse (or fail their range check) are logged
+// and left at their current value. AllowOrigins is rebuilt from scratch on
+// each call so that repeated calls neither duplicate the default origins nor
+// keep origins from an earlier OLLAMA_ORIGINS value.
+func LoadConfig() {
+	if debug := clean("OLLAMA_DEBUG"); debug != "" {
+		// A non-empty value that is not a valid boolean still enables debug.
+		d, err := strconv.ParseBool(debug)
+		if err == nil {
+			Debug = d
+		} else {
+			Debug = true
+		}
+	}
+
+	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
+	if runtime.GOOS == "windows" && RunnersDir == "" {
+		// On Windows we do not carry the payloads inside the main executable
+		appExe, err := os.Executable()
+		if err != nil {
+			slog.Error("failed to lookup executable path", "error", err)
+		}
+
+		cwd, err := os.Getwd()
+		if err != nil {
+			slog.Error("failed to lookup working directory", "error", err)
+		}
+
+		// Try a few variations to improve developer experience when building from source in the local tree
+		var paths []string
+		for _, root := range []string{filepath.Dir(appExe), cwd} {
+			paths = append(paths,
+				filepath.Join(root),
+				filepath.Join(root, "windows-"+runtime.GOARCH),
+				filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
+			)
+		}
+
+		// The first candidate that exists wins.
+		for _, p := range paths {
+			candidate := filepath.Join(p, "ollama_runners")
+			_, err := os.Stat(candidate)
+			if err == nil {
+				RunnersDir = candidate
+				break
+			}
+		}
+		if RunnersDir == "" {
+			slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
+		}
+	}
+
+	TmpDir = clean("OLLAMA_TMPDIR")
+
+	userLimit := clean("OLLAMA_MAX_VRAM")
+	if userLimit != "" {
+		avail, err := strconv.ParseUint(userLimit, 10, 64)
+		if err != nil {
+			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
+		} else {
+			MaxVRAM = avail
+		}
+	}
+
+	LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
+
+	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
+		val, err := strconv.Atoi(onp)
+		if err != nil || val <= 0 {
+			slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
+		} else {
+			NumParallel = val
+		}
+	}
+
+	// Any non-empty value disables pruning; the value itself is not parsed.
+	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
+		NoPrune = true
+	}
+
+	// Start from an empty list: appending the defaults to the previous value
+	// would grow AllowOrigins on every call.
+	AllowOrigins = nil
+	if origins := clean("OLLAMA_ORIGINS"); origins != "" {
+		AllowOrigins = strings.Split(origins, ",")
+	}
+	// The local defaults are always allowed, over http and https, with and
+	// without a port wildcard.
+	for _, allowOrigin := range defaultAllowOrigins {
+		AllowOrigins = append(AllowOrigins,
+			fmt.Sprintf("http://%s", allowOrigin),
+			fmt.Sprintf("https://%s", allowOrigin),
+			fmt.Sprintf("http://%s:*", allowOrigin),
+			fmt.Sprintf("https://%s:*", allowOrigin),
+		)
+	}
+
+	maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
+	if maxRunners != "" {
+		m, err := strconv.Atoi(maxRunners)
+		if err != nil {
+			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
+		} else {
+			MaxRunners = m
+		}
+	}
+
+	// Read through clean like every other setting so a quoted or padded value
+	// is accepted.
+	if maxQueue := clean("OLLAMA_MAX_QUEUE"); maxQueue != "" {
+		p, err := strconv.Atoi(maxQueue)
+		if err != nil || p <= 0 {
+			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", maxQueue, "error", err)
+		} else {
+			MaxQueuedRequests = p
+		}
+	}
+}
diff --git a/server/envconfig/config_test.go b/server/envconfig/config_test.go
new file mode 100644
index 00000000..b2760299
--- /dev/null
+++ b/server/envconfig/config_test.go
@@ -0,0 +1,20 @@
+package envconfig
+
+import (
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestConfig checks how LoadConfig interprets OLLAMA_DEBUG. The cases run in
+// order against the same package-level Debug variable.
+func TestConfig(t *testing.T) {
+	cases := []struct {
+		value string
+		want  bool
+	}{
+		{value: "", want: false},
+		{value: "false", want: false},
+		{value: "1", want: true},
+	}
+
+	for _, tc := range cases {
+		os.Setenv("OLLAMA_DEBUG", tc.value)
+		LoadConfig()
+		if tc.want {
+			require.True(t, Debug)
+		} else {
+			require.False(t, Debug)
+		}
+	}
+}
diff --git a/server/images.go b/server/images.go
index 74fa1a5e..2be1d366 100644
--- a/server/images.go
+++ b/server/images.go
@@ -1,16 +1,16 @@
package server
import (
- "archive/zip"
"bytes"
+ "cmp"
"context"
"crypto/sha256"
+ "encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
- "io/fs"
"log"
"log/slog"
"net/http"
@@ -20,15 +20,16 @@ import (
"runtime"
"strconv"
"strings"
- "text/template"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
- "github.com/ollama/ollama/convert"
+ "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
- "github.com/ollama/ollama/parser"
+ "github.com/ollama/ollama/server/envconfig"
+ "github.com/ollama/ollama/types/errtypes"
+ "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -51,7 +52,6 @@ type Model struct {
System string
License []string
Digest string
- Size int64
Options map[string]interface{}
Messages []Message
}
@@ -60,6 +60,76 @@ func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
}
+// String rebuilds a model.File from the fields of m and returns the result of
+// that file's String method. Commands are added in a fixed order: the base
+// model, adapters, projectors (as additional "model" commands), template,
+// system, options, licenses, and messages. Options are added in sorted key
+// order so the command list is the same on every call.
+func (m *Model) String() string {
+	var modelfile model.File
+
+	modelfile.Commands = append(modelfile.Commands, model.Command{
+		Name: "model",
+		Args: m.ModelPath,
+	})
+
+	for _, adapter := range m.AdapterPaths {
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "adapter",
+			Args: adapter,
+		})
+	}
+
+	for _, projector := range m.ProjectorPaths {
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "model",
+			Args: projector,
+		})
+	}
+
+	if m.Template != "" {
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "template",
+			Args: m.Template,
+		})
+	}
+
+	if m.System != "" {
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "system",
+			Args: m.System,
+		})
+	}
+
+	// Map iteration order is random; sort the keys so the output is stable.
+	keys := make([]string, 0, len(m.Options))
+	for k := range m.Options {
+		keys = append(keys, k)
+	}
+	slices.Sort(keys)
+
+	for _, k := range keys {
+		switch v := m.Options[k].(type) {
+		case []any:
+			// A list-valued option becomes one command per element.
+			for _, s := range v {
+				modelfile.Commands = append(modelfile.Commands, model.Command{
+					Name: k,
+					Args: fmt.Sprintf("%v", s),
+				})
+			}
+		default:
+			modelfile.Commands = append(modelfile.Commands, model.Command{
+				Name: k,
+				Args: fmt.Sprintf("%v", v),
+			})
+		}
+	}
+
+	for _, license := range m.License {
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "license",
+			Args: license,
+		})
+	}
+
+	for _, msg := range m.Messages {
+		modelfile.Commands = append(modelfile.Commands, model.Command{
+			Name: "message",
+			Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
+		})
+	}
+
+	return modelfile.String()
+}
+
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
@@ -85,50 +155,11 @@ type ConfigV2 struct {
RootFS RootFS `json:"rootfs"`
}
-func (c *ConfigV2) SetModelFormat(format string) {
- if c.ModelFormat == "" {
- c.ModelFormat = format
- }
-}
-
-func (c *ConfigV2) SetModelFamily(families ...string) {
- for _, family := range families {
- if c.ModelFamily == "" {
- c.ModelFamily = family
- }
-
- if !slices.Contains(c.ModelFamilies, family) {
- c.ModelFamilies = append(c.ModelFamilies, family)
- }
- }
-}
-
-func (c *ConfigV2) SetModelType(modelType string) {
- if c.ModelType == "" {
- c.ModelType = modelType
- }
-}
-
-func (c *ConfigV2) SetFileType(fileType string) {
- if c.FileType == "" {
- c.FileType = fileType
- }
-}
-
type RootFS struct {
Type string `json:"type"`
DiffIDs []string `json:"diff_ids"`
}
-func (m *ManifestV2) GetTotalSize() (total int64) {
- for _, layer := range m.Layers {
- total += layer.Size
- }
-
- total += m.Config.Size
- return total
-}
-
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
@@ -169,7 +200,6 @@ func GetModel(name string) (*Model, error) {
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
- Size: manifest.GetTotalSize(),
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -259,7 +289,7 @@ func GetModel(name string) (*Model, error) {
return model, nil
}
-func realpath(mfDir, from string) string {
+func realpath(rel, from string) string {
abspath, err := filepath.Abs(from)
if err != nil {
return from
@@ -276,22 +306,15 @@ func realpath(mfDir, from string) string {
return filepath.Join(home, from[2:])
}
- if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil {
+ if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
// this is a file relative to the Modelfile
- return filepath.Join(mfDir, from)
+ return filepath.Join(rel, from)
}
return abspath
}
-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
- deleteMap := make(map[string]struct{})
- if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
- for _, layer := range append(manifest.Layers, manifest.Config) {
- deleteMap[layer.Digest] = struct{}{}
- }
- }
-
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
@@ -300,250 +323,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
},
}
- var layers Layers
- messages := []string{}
+ var messages []*api.Message
+ parameters := make(map[string]any)
- params := make(map[string][]string)
- fromParams := make(map[string]any)
-
- for _, c := range commands {
+ var layers []*Layer
+ for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name {
- case "model":
- if strings.HasPrefix(c.Args, "@") {
- blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+ case "model", "adapter":
+ var baseLayers []*layerWithGGML
+ if name := model.ParseName(c.Args); name.IsValid() {
+ baseLayers, err = parseFromModel(ctx, name, fn)
+ if err != nil {
+ return err
+ }
+ } else if strings.HasPrefix(c.Args, "@") {
+ blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
- c.Args = blobPath
- }
-
- pathName := realpath(modelFileDir, c.Args)
-
- ggufName, err := convertModel(name, pathName, fn)
- if err != nil {
- var pathErr *fs.PathError
- switch {
- case errors.Is(err, zip.ErrFormat):
- // it's not a safetensor archive
- case errors.As(err, &pathErr):
- // it's not a file on disk, could be a model reference
- default:
+ blob, err := os.Open(blobpath)
+ if err != nil {
return err
}
+ defer blob.Close()
+
+ baseLayers, err = parseFromFile(ctx, blob, fn)
+ if err != nil {
+ return err
+ }
+ } else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
+ defer file.Close()
+
+ baseLayers, err = parseFromFile(ctx, file, fn)
+ if err != nil {
+ return err
+ }
+ } else {
+ return fmt.Errorf("invalid model reference: %s", c.Args)
}
- if ggufName != "" {
- pathName = ggufName
- defer os.RemoveAll(ggufName)
-
- if quantization != "" {
- quantization = strings.ToUpper(quantization)
- fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
- tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
+ for _, baseLayer := range baseLayers {
+ if quantization != "" &&
+ baseLayer.MediaType == "application/vnd.ollama.image.model" &&
+ baseLayer.GGML != nil &&
+ baseLayer.GGML.Name() == "gguf" {
+ want, err := llm.ParseFileType(quantization)
if err != nil {
return err
}
- defer os.RemoveAll(tempfile.Name())
- if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
- return err
- }
+ ft := baseLayer.GGML.KV().FileType()
+ if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
+ return errors.New("quantization is only supported for F16 and F32 models")
+ } else if want != ft {
+ fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
- if err := tempfile.Close(); err != nil {
- return err
- }
-
- pathName = tempfile.Name()
- }
- }
-
- bin, err := os.Open(pathName)
- if err != nil {
- // not a file on disk so must be a model reference
- modelpath := ParseModelPath(c.Args)
- manifest, _, err := GetManifest(modelpath)
- switch {
- case errors.Is(err, os.ErrNotExist):
- fn(api.ProgressResponse{Status: "pulling model"})
- if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
- return err
- }
-
- manifest, _, err = GetManifest(modelpath)
- if err != nil {
- return err
- }
- case err != nil:
- return err
- }
-
- fn(api.ProgressResponse{Status: "reading model metadata"})
- fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
- if err != nil {
- return err
- }
-
- fromConfigFile, err := os.Open(fromConfigPath)
- if err != nil {
- return err
- }
- defer fromConfigFile.Close()
-
- var fromConfig ConfigV2
- if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
- return err
- }
-
- // if the model is still not in gguf format, error out
- if fromConfig.ModelFormat != "gguf" {
- return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
- }
-
- config.SetModelFormat(fromConfig.ModelFormat)
- config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
- config.SetModelType(fromConfig.ModelType)
- config.SetFileType(fromConfig.FileType)
-
- for _, layer := range manifest.Layers {
- deleteMap[layer.Digest] = struct{}{}
- if layer.MediaType == "application/vnd.ollama.image.params" {
- fromParamsPath, err := GetBlobsPath(layer.Digest)
+ blob, err := GetBlobsPath(baseLayer.Digest)
if err != nil {
return err
}
- fromParamsFile, err := os.Open(fromParamsPath)
+ temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
if err != nil {
return err
}
- defer fromParamsFile.Close()
+ defer temp.Close()
+ defer os.Remove(temp.Name())
- if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
+ if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+ return err
+ }
+
+ baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
+ if err != nil {
return err
}
}
-
- layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
- if err != nil {
- return err
- }
-
- layers.Add(layer)
}
- deleteMap[manifest.Config.Digest] = struct{}{}
- continue
+ if baseLayer.GGML != nil {
+ config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
+ config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
+ config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+ config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
+ config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
+ }
+
+ layers = append(layers, baseLayer.Layer)
}
- defer bin.Close()
-
- var offset int64
- for {
- fn(api.ProgressResponse{Status: "creating model layer"})
- if _, err := bin.Seek(offset, io.SeekStart); err != nil {
- return err
- }
-
- ggml, size, err := llm.DecodeGGML(bin)
- if errors.Is(err, io.EOF) {
- break
- } else if errors.Is(err, llm.ErrUnsupportedFormat) {
- return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
- } else if err != nil {
- return err
- }
-
- config.SetModelFormat(ggml.Name())
- config.SetModelFamily(ggml.KV().Architecture())
- config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
- config.SetFileType(ggml.KV().FileType())
-
- mediatype := mediatype
- if ggml.KV().Architecture() == "clip" {
- mediatype = "application/vnd.ollama.image.projector"
- }
-
- sr := io.NewSectionReader(bin, offset, size)
- layer, err := NewLayer(sr, mediatype)
- if err != nil {
- return err
- }
-
- layers.Add(layer)
-
- offset += size
- }
- case "adapter":
- if strings.HasPrefix(c.Args, "@") {
- blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
- if err != nil {
- return err
- }
-
- c.Args = blobPath
- }
-
- fn(api.ProgressResponse{Status: "creating adapter layer"})
- bin, err := os.Open(realpath(modelFileDir, c.Args))
- if err != nil {
- return err
- }
- defer bin.Close()
-
- _, size, err := llm.DecodeGGML(bin)
+ case "license", "template", "system":
+ blob := strings.NewReader(c.Args)
+ layer, err := NewLayer(blob, mediatype)
if err != nil {
return err
}
- sr := io.NewSectionReader(bin, 0, size)
- layer, err := NewLayer(sr, mediatype)
- if err != nil {
- return err
+ if c.Name != "license" {
+ // replace
+ layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+ return layer.MediaType == mediatype
+ })
}
- layers.Add(layer)
- case "license":
- fn(api.ProgressResponse{Status: "creating license layer"})
-
- bin := strings.NewReader(c.Args)
- layer, err := NewLayer(bin, mediatype)
- if err != nil {
- return err
- }
-
- layers.Add(layer)
- case "template", "system":
- fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
-
- bin := strings.NewReader(c.Args)
- layer, err := NewLayer(bin, mediatype)
- if err != nil {
- return err
- }
-
- layers.Replace(layer)
+ layers = append(layers, layer)
case "message":
- messages = append(messages, c.Args)
+ role, content, ok := strings.Cut(c.Args, ": ")
+ if !ok {
+ return fmt.Errorf("invalid message: %s", c.Args)
+ }
+
+ messages = append(messages, &api.Message{Role: role, Content: content})
default:
- params[c.Name] = append(params[c.Name], c.Args)
+ ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
+ if err != nil {
+ return err
+ }
+
+ for k, v := range ps {
+ if ks, ok := parameters[k].([]string); ok {
+ parameters[k] = append(ks, v.([]string)...)
+ } else if vs, ok := v.([]string); ok {
+ parameters[k] = vs
+ } else {
+ parameters[k] = v
+ }
+ }
}
}
- if len(messages) > 0 {
- fn(api.ProgressResponse{Status: "creating parameters layer"})
+ var err2 error
+ layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+ switch layer.MediaType {
+ case "application/vnd.ollama.image.message":
+ // if there are new messages, remove the inherited ones
+ if len(messages) > 0 {
+ return true
+ }
- msgs := make([]api.Message, 0)
+ return false
+ case "application/vnd.ollama.image.params":
+ // merge inherited parameters with new ones
+ r, err := layer.Open()
+ if err != nil {
+ err2 = err
+ return false
+ }
+ defer r.Close()
- for _, m := range messages {
- // todo: handle images
- msg := strings.SplitN(m, ": ", 2)
- msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
+ var ps map[string]any
+ if err := json.NewDecoder(r).Decode(&ps); err != nil {
+ err2 = err
+ return false
+ }
+
+ for k, v := range ps {
+ if _, ok := parameters[k]; !ok {
+ parameters[k] = v
+ }
+ }
+
+ return true
+ default:
+ return false
}
+ })
+ if err2 != nil {
+ return err2
+ }
+
+ if len(messages) > 0 {
var b bytes.Buffer
- if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+ if err := json.NewEncoder(&b).Encode(messages); err != nil {
return err
}
@@ -552,39 +506,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return err
}
- layers.Replace(layer)
+ layers = append(layers, layer)
}
- if len(params) > 0 {
- fn(api.ProgressResponse{Status: "creating parameters layer"})
-
- formattedParams, err := api.FormatParams(params)
- if err != nil {
- return err
- }
-
- for k, v := range fromParams {
- if _, ok := formattedParams[k]; !ok {
- formattedParams[k] = v
- }
- }
-
+ if len(parameters) > 0 {
var b bytes.Buffer
- if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
+ if err := json.NewEncoder(&b).Encode(parameters); err != nil {
return err
}
- fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return err
}
- layers.Replace(layer)
+ layers = append(layers, layer)
}
- digests := make([]string, len(layers.items))
- for i, layer := range layers.items {
+ digests := make([]string, len(layers))
+ for i, layer := range layers {
digests[i] = layer.Digest
}
@@ -595,36 +535,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return err
}
- configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+ layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}
- delete(deleteMap, configLayer.Digest)
+ for _, layer := range append(layers, layer) {
+ if layer.status != "" {
+ fn(api.ProgressResponse{Status: layer.status})
+ }
+ }
- for _, layer := range append(layers.items, configLayer) {
- committed, err := layer.Commit()
- if err != nil {
- return err
+ unref := make(map[string]struct{})
+ if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
+ for _, layer := range manifest.Layers {
+ if !slices.Contains(digests, layer.Digest) {
+ unref[layer.Digest] = struct{}{}
+ }
}
- status := "writing layer"
- if !committed {
- status = "using already created layer"
+ if manifest.Config.Digest != layer.Digest {
+ unref[manifest.Config.Digest] = struct{}{}
}
-
- fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
-
- delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"})
- if err := WriteManifest(name, configLayer, layers.items); err != nil {
+ if err := WriteManifest(name, layer, layers); err != nil {
return err
}
- if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
- if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+ if !envconfig.NoPrune {
+ if err := deleteUnusedLayers(nil, unref, false); err != nil {
return err
}
}
@@ -633,104 +574,43 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return nil
}
-func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
- r, err := zip.OpenReader(path)
- if err != nil {
- return "", err
+func CopyModel(src, dst model.Name) error {
+ if !dst.IsFullyQualified() {
+ return model.Unqualified(dst)
}
- defer r.Close()
-
- tempDir, err := os.MkdirTemp("", "ollama-convert")
- if err != nil {
- return "", err
- }
- defer os.RemoveAll(tempDir)
-
- fn(api.ProgressResponse{Status: "unpacking model metadata"})
- for _, f := range r.File {
- fpath := filepath.Join(tempDir, f.Name)
- outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
- if err != nil {
- return "", err
- }
-
- rc, err := f.Open()
- if err != nil {
- return "", err
- }
-
- _, err = io.Copy(outFile, rc)
- if err != nil {
- return "", err
- }
-
- outFile.Close()
- rc.Close()
+ if !src.IsFullyQualified() {
+ return model.Unqualified(src)
}
- mf, err := convert.GetModelFormat(tempDir)
- if err != nil {
- return "", err
+ if src.Filepath() == dst.Filepath() {
+ return nil
}
- params, err := mf.GetParams(tempDir)
- if err != nil {
- return "", err
- }
-
- mArch, err := mf.GetModelArch(name, tempDir, params)
- if err != nil {
- return "", err
- }
-
- fn(api.ProgressResponse{Status: "processing tensors"})
- if err := mArch.GetTensors(); err != nil {
- return "", err
- }
-
- if err := mArch.LoadVocab(); err != nil {
- return "", err
- }
-
- fn(api.ProgressResponse{Status: "converting model"})
- path, err = mArch.WriteGGUF()
- if err != nil {
- return "", err
- }
-
- return path, nil
-}
-
-func CopyModel(src, dest string) error {
- srcModelPath := ParseModelPath(src)
- srcPath, err := srcModelPath.GetManifestPath()
+ manifests, err := GetManifestPath()
if err != nil {
return err
}
- destModelPath := ParseModelPath(dest)
- destPath, err := destModelPath.GetManifestPath()
- if err != nil {
- return err
- }
- if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
+ dstpath := filepath.Join(manifests, dst.Filepath())
+ if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
- // copy the file
- input, err := os.ReadFile(srcPath)
+ srcpath := filepath.Join(manifests, src.Filepath())
+ srcfile, err := os.Open(srcpath)
if err != nil {
- fmt.Println("Error reading file:", err)
return err
}
+ defer srcfile.Close()
- err = os.WriteFile(destPath, input, 0o644)
+ dstfile, err := os.Create(dstpath)
if err != nil {
- fmt.Println("Error reading file:", err)
return err
}
+ defer dstfile.Close()
- return nil
+ _, err = io.Copy(dstfile, srcfile)
+ return err
}
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
@@ -890,67 +770,6 @@ func DeleteModel(name string) error {
return nil
}
-func ShowModelfile(model *Model) (string, error) {
- var mt struct {
- *Model
- From string
- Parameters map[string][]any
- }
-
- mt.Parameters = make(map[string][]any)
- for k, v := range model.Options {
- if s, ok := v.([]any); ok {
- mt.Parameters[k] = s
- continue
- }
-
- mt.Parameters[k] = []any{v}
- }
-
- mt.Model = model
- mt.From = model.ModelPath
-
- if model.ParentModel != "" {
- mt.From = model.ParentModel
- }
-
- modelFile := `# Modelfile generated by "ollama show"
-# To build a new Modelfile based on this one, replace the FROM line with:
-# FROM {{ .ShortName }}
-
-FROM {{ .From }}
-TEMPLATE """{{ .Template }}"""
-
-{{- if .System }}
-SYSTEM """{{ .System }}"""
-{{- end }}
-
-{{- range $adapter := .AdapterPaths }}
-ADAPTER {{ $adapter }}
-{{- end }}
-
-{{- range $k, $v := .Parameters }}
-{{- range $parameter := $v }}
-PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
-{{- end }}
-{{- end }}`
-
- tmpl, err := template.New("").Parse(modelFile)
- if err != nil {
- slog.Info(fmt.Sprintf("error parsing template: %q", err))
- return "", err
- }
-
- var buf bytes.Buffer
-
- if err = tmpl.Execute(&buf, mt); err != nil {
- slog.Info(fmt.Sprintf("error executing template: %q", err))
- return "", err
- }
-
- return buf.String(), nil
-}
-
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -972,9 +791,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
- if errors.Is(err, errUnauthorized) {
- return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
- }
return err
}
}
@@ -1011,7 +827,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
- if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+ if !envconfig.NoPrune {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
@@ -1137,9 +953,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
-var errUnauthorized = fmt.Errorf("unauthorized")
+var errUnauthorized = fmt.Errorf("unauthorized: access denied")
+
+// getTokenSubject returns the "sub" (subject) claim from the payload segment
+// of a JWT. It does not verify the signature or validate any other claim, so
+// the result is only a hint about who the token was issued to.
+//
+// It logs the reason and returns "" when the token does not consist of three
+// dot-separated parts, when the payload is not unpadded base64url-encoded
+// JSON, or when the payload has no string-valued "sub" claim.
+func getTokenSubject(token string) string {
+	parts := strings.Split(token, ".")
+	if len(parts) != 3 {
+		slog.Error("jwt token does not contain 3 parts")
+		return ""
+	}
+
+	// parts[1] is the claims payload; the header and signature are ignored.
+	payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
+	if err != nil {
+		slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
+		return ""
+	}
+
+	var payloadMap map[string]any
+	if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
+		slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
+		return ""
+	}
+
+	sub, ok := payloadMap["sub"]
+	if !ok {
+		slog.Error("jwt does not contain 'sub' field")
+		return ""
+	}
+
+	// Formatting a non-string claim with %s would yield text such as
+	// "%!s(float64=1)"; treat it as an absent subject instead.
+	s, ok := sub.(string)
+	if !ok {
+		slog.Error("jwt 'sub' field is not a string")
+		return ""
+	}
+
+	return s
+}
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
+ anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
@@ -1158,6 +1005,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, err
}
+ anonymous = getTokenSubject(token) == "anonymous"
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
@@ -1178,6 +1026,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
}
}
+ if anonymous {
+ // no user is associated with the public key, and the request requires non-anonymous access
+ pubKey, nestedErr := auth.GetPublicKey()
+ if nestedErr != nil {
+ slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
+ return nil, errUnauthorized
+ }
+ return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
+ }
+ // user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized
}
@@ -1255,7 +1113,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
}
}
-var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
+var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
diff --git a/server/layers.go b/server/layer.go
similarity index 53%
rename from server/layers.go
rename to server/layer.go
index 07787406..dcca3854 100644
--- a/server/layers.go
+++ b/server/layer.go
@@ -5,39 +5,14 @@ import (
"fmt"
"io"
"os"
- "strings"
-
- "golang.org/x/exp/slices"
)
-type Layers struct {
- items []*Layer
-}
-
-func (ls *Layers) Add(layer *Layer) {
- if layer.Size > 0 {
- ls.items = append(ls.items, layer)
- }
-}
-
-func (ls *Layers) Replace(layer *Layer) {
- if layer.Size > 0 {
- mediatype := layer.MediaType
- layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
- return l.MediaType == mediatype
- })
-
- ls.items = append(layers, layer)
- }
-}
-
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
-
- tempFileName string
+ status string
}
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
return nil, err
}
- const delimiter = "-"
-
- pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
- temp, err := os.CreateTemp(blobs, pattern)
+ temp, err := os.CreateTemp(blobs, "sha256-")
if err != nil {
return nil, err
}
defer temp.Close()
+ defer os.Remove(temp.Name())
sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
return nil, err
}
+ if err := temp.Close(); err != nil {
+ return nil, err
+ }
+
+ digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
+ blob, err := GetBlobsPath(digest)
+ if err != nil {
+ return nil, err
+ }
+
+ status := "using existing layer"
+ if _, err := os.Stat(blob); err != nil {
+ status = "creating new layer"
+ if err := os.Rename(temp.Name(), blob); err != nil {
+ return nil, err
+ }
+ }
+
return &Layer{
- MediaType: mediatype,
- Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
- Size: n,
- tempFileName: temp.Name(),
+ MediaType: mediatype,
+ Digest: digest,
+ Size: n,
+ status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
+ status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}
-func (l *Layer) Commit() (bool, error) {
- // always remove temp
- defer os.Remove(l.tempFileName)
-
+func (l *Layer) Open() (io.ReadCloser, error) {
blob, err := GetBlobsPath(l.Digest)
if err != nil {
- return false, err
+ return nil, err
}
- if _, err := os.Stat(blob); err != nil {
- return true, os.Rename(l.tempFileName, blob)
- }
-
- return false, nil
+ return os.Open(blob)
}
diff --git a/server/manifest.go b/server/manifest.go
new file mode 100644
index 00000000..8a17700e
--- /dev/null
+++ b/server/manifest.go
@@ -0,0 +1,79 @@
+package server
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+
+ "github.com/ollama/ollama/types/model"
+)
+
+// Manifest is a ManifestV2 plus the digest of the manifest file it was decoded
+// from. ParseNamedManifest fills Digest with a hex-encoded SHA-256; the field
+// is tagged `json:"-"` so it is never written back into the manifest JSON.
+type Manifest struct {
+ ManifestV2
+ Digest string `json:"-"`
+}
+
+// Size returns the combined size in bytes of all layers referenced by the
+// manifest, including the config layer when one is set.
+func (m *Manifest) Size() (size int64) {
+	// Sum Layers and Config separately: append(m.Layers, m.Config) may write
+	// into spare capacity of the backing array shared with m.Layers, and
+	// dereferencing a nil Config would panic.
+	for _, layer := range m.Layers {
+		size += layer.Size
+	}
+
+	if m.Config != nil {
+		size += m.Config.Size
+	}
+
+	return size
+}
+
+// ParseNamedManifest opens and decodes the manifest stored for name beneath
+// the directory returned by GetManifestPath.
+//
+// It returns model.Unqualified(name) when name is not fully qualified, and
+// otherwise any error from locating, opening, or JSON-decoding the manifest
+// file. The Digest of the returned Manifest is the hex-encoded SHA-256 of the
+// bytes read from the file while decoding.
+func ParseNamedManifest(name model.Name) (*Manifest, error) {
+	if !name.IsFullyQualified() {
+		return nil, model.Unqualified(name)
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	if err != nil {
+		return nil, err
+	}
+	// The file is only read, so a Close error carries no useful information.
+	defer manifestfile.Close()
+
+	// Hash the same bytes the decoder consumes by teeing them into sha256sum.
+	var manifest ManifestV2
+	sha256sum := sha256.New()
+	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+		return nil, err
+	}
+
+	return &Manifest{
+		ManifestV2: manifest,
+		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+	}, nil
+}
+
+// WriteManifest JSON-encodes a schema-version-2 manifest that references
+// config and layers, and writes it to the manifest path derived from name.
+//
+// Missing parent directories are created with mode 0o755 and the manifest file
+// is written with mode 0o644. It returns any error from encoding, resolving
+// the manifest path, creating directories, or writing the file.
+func WriteManifest(name string, config *Layer, layers []*Layer) error {
+ manifest := ManifestV2{
+ SchemaVersion: 2,
+ MediaType: "application/vnd.docker.distribution.manifest.v2+json",
+ Config: config,
+ Layers: layers,
+ }
+
+ // Encode fully into memory first so nothing is written if encoding fails.
+ var b bytes.Buffer
+ if err := json.NewEncoder(&b).Encode(manifest); err != nil {
+ return err
+ }
+
+ modelpath := ParseModelPath(name)
+ manifestPath, err := modelpath.GetManifestPath()
+ if err != nil {
+ return err
+ }
+
+ if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
+ return err
+ }
+
+ return os.WriteFile(manifestPath, b.Bytes(), 0o644)
+}
diff --git a/server/manifests.go b/server/manifests.go
deleted file mode 100644
index 2b39db65..00000000
--- a/server/manifests.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package server
-
-import (
- "bytes"
- "encoding/json"
- "os"
- "path/filepath"
-)
-
-func WriteManifest(name string, config *Layer, layers []*Layer) error {
- manifest := ManifestV2{
- SchemaVersion: 2,
- MediaType: "application/vnd.docker.distribution.manifest.v2+json",
- Config: config,
- Layers: layers,
- }
-
- var b bytes.Buffer
- if err := json.NewEncoder(&b).Encode(manifest); err != nil {
- return err
- }
-
- modelpath := ParseModelPath(name)
- manifestPath, err := modelpath.GetManifestPath()
- if err != nil {
- return err
- }
-
- if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
- return err
- }
-
- return os.WriteFile(manifestPath, b.Bytes(), 0o644)
-}
diff --git a/server/model.go b/server/model.go
new file mode 100644
index 00000000..eea5d13a
--- /dev/null
+++ b/server/model.go
@@ -0,0 +1,261 @@
+package server
+
+import (
+ "archive/zip"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/convert"
+ "github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/types/model"
+)
+
+// layerWithGGML pairs a Layer with the GGML value decoded from its blob.
+// GGML is nil for layers that parseFromModel does not decode (any media type
+// other than model, projector, or adapter).
+type layerWithGGML struct {
+ *Layer
+ *llm.GGML
+}
+
+// parseFromModel resolves name to a local manifest, pulling the model first
+// when no manifest exists locally, and returns one entry per manifest layer.
+//
+// Model, projector, and adapter layers are opened and decoded with
+// llm.DecodeGGML; every other layer is returned with a nil GGML. Each Layer is
+// rebuilt through NewLayerFromLayer with the model's short tag name as its
+// "from" value. fn is passed to PullModel to receive progress updates.
+func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+	modelpath := ParseModelPath(name.String())
+	manifest, _, err := GetManifest(modelpath)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// Not available locally: pull it, then read the manifest again.
+		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
+			return nil, err
+		}
+
+		modelpath = ParseModelPath(name.String())
+		manifest, _, err = GetManifest(modelpath)
+		if err != nil {
+			return nil, err
+		}
+	case err != nil:
+		return nil, err
+	}
+
+	for _, layer := range manifest.Layers {
+		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
+		if err != nil {
+			return nil, err
+		}
+
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.model",
+			"application/vnd.ollama.image.projector",
+			"application/vnd.ollama.image.adapter":
+			blobpath, err := GetBlobsPath(layer.Digest)
+			if err != nil {
+				return nil, err
+			}
+
+			blob, err := os.Open(blobpath)
+			if err != nil {
+				return nil, err
+			}
+
+			ggml, _, err := llm.DecodeGGML(blob)
+			// Close here rather than with defer: a defer inside the loop
+			// would keep every blob open until the function returns.
+			blob.Close()
+			if err != nil {
+				return nil, err
+			}
+
+			layers = append(layers, &layerWithGGML{layer, ggml})
+		default:
+			layers = append(layers, &layerWithGGML{layer, nil})
+		}
+	}
+
+	return layers, nil
+}
+
+// parseFromZipFile converts a zip archive of model files into a single GGUF
+// model layer.
+//
+// Every archive entry is unpacked into a temporary directory created beside
+// file; the directory is removed again before returning. The unpacked files
+// are handed to the convert package, the converted model is written out as
+// GGUF and stored as an "application/vnd.ollama.image.model" layer, and that
+// layer's blob is decoded so the returned entry carries its GGML value.
+// fn receives coarse progress updates. Entries whose names would resolve
+// outside the temporary directory are rejected with an error.
+func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+	stat, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	r, err := zip.NewReader(file, stat.Size())
+	if err != nil {
+		return nil, err
+	}
+
+	tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+	if err != nil {
+		return nil, err
+	}
+	defer os.RemoveAll(tempdir)
+
+	fn(api.ProgressResponse{Status: "unpacking model metadata"})
+	for _, f := range r.File {
+		// TODO(mxyng): this should not write out all files to disk
+		if err := extractZipEntry(tempdir, f); err != nil {
+			return nil, err
+		}
+	}
+
+	mf, err := convert.GetModelFormat(tempdir)
+	if err != nil {
+		return nil, err
+	}
+
+	params, err := mf.GetParams(tempdir)
+	if err != nil {
+		return nil, err
+	}
+
+	mArch, err := mf.GetModelArch("", tempdir, params)
+	if err != nil {
+		return nil, err
+	}
+
+	fn(api.ProgressResponse{Status: "processing tensors"})
+	if err := mArch.GetTensors(); err != nil {
+		return nil, err
+	}
+
+	if err := mArch.LoadVocab(); err != nil {
+		return nil, err
+	}
+
+	fn(api.ProgressResponse{Status: "converting model"})
+
+	// TODO(mxyng): this should write directly into a layer
+	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
+	temp, err := os.CreateTemp(tempdir, "fp16")
+	if err != nil {
+		return nil, err
+	}
+	defer temp.Close()
+	defer os.Remove(temp.Name())
+
+	if err = mArch.WriteGGUF(temp); err != nil {
+		return nil, err
+	}
+
+	// Rewind so NewLayer reads the GGUF data from the beginning.
+	if _, err := temp.Seek(0, io.SeekStart); err != nil {
+		return nil, err
+	}
+
+	layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
+	if err != nil {
+		return nil, err
+	}
+
+	blobpath, err := GetBlobsPath(layer.Digest)
+	if err != nil {
+		return nil, err
+	}
+
+	bin, err := os.Open(blobpath)
+	if err != nil {
+		return nil, err
+	}
+	defer bin.Close()
+
+	ggml, _, err := llm.DecodeGGML(bin)
+	if err != nil {
+		return nil, err
+	}
+
+	layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
+	if err != nil {
+		return nil, err
+	}
+
+	layers = append(layers, &layerWithGGML{layer, ggml})
+	return layers, nil
+}
+
+// extractZipEntry writes the contents of the archive entry f to a file named
+// after the entry inside tempdir. Entry names come from the archive and are
+// untrusted, so names that are not local (empty, absolute, or containing ".."
+// elements) are rejected instead of being joined onto tempdir.
+func extractZipEntry(tempdir string, f *zip.File) error {
+	if !filepath.IsLocal(f.Name) {
+		return fmt.Errorf("zip entry %q is outside the extraction directory", f.Name)
+	}
+
+	outfile, err := os.Create(filepath.Join(tempdir, f.Name))
+	if err != nil {
+		return err
+	}
+	defer outfile.Close()
+
+	infile, err := f.Open()
+	if err != nil {
+		return err
+	}
+	defer infile.Close()
+
+	if _, err := io.Copy(outfile, infile); err != nil {
+		return err
+	}
+
+	// Close explicitly so a failure is reported to the caller; the deferred
+	// calls above only cover the early-return paths.
+	if err := outfile.Close(); err != nil {
+		return err
+	}
+
+	return infile.Close()
+}
+
+// parseFromFile builds layers from a local model file.
+//
+// The first 512 bytes are sniffed with detectContentType: "gguf" and "ggla"
+// files are decoded in place, "application/zip" is delegated to
+// parseFromZipFile, and any other type is rejected with an error. Each decoded
+// GGML segment is stored through NewLayer as a model, adapter, or projector
+// layer depending on what was decoded.
+func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+ // A SectionReader reads via ReadAt, so sniffing does not move file's offset.
+ sr := io.NewSectionReader(file, 0, 512)
+ contentType, err := detectContentType(sr)
+ if err != nil {
+ return nil, err
+ }
+
+ switch contentType {
+ case "gguf", "ggla":
+ // noop
+ case "application/zip":
+ return parseFromZipFile(ctx, file, fn)
+ default:
+ return nil, fmt.Errorf("unsupported content type: %s", contentType)
+ }
+
+ stat, err := file.Stat()
+ if err != nil {
+ return nil, err
+ }
+
+ // Keep decoding until the whole file is consumed or DecodeGGML reports EOF.
+ var offset int64
+ for offset < stat.Size() {
+ ggml, n, err := llm.DecodeGGML(file)
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ // ggla files are adapters, clip architectures are projectors, and
+ // everything else is treated as a base model.
+ mediatype := "application/vnd.ollama.image.model"
+ if ggml.Name() == "ggla" {
+ mediatype = "application/vnd.ollama.image.adapter"
+ } else if ggml.KV().Architecture() == "clip" {
+ mediatype = "application/vnd.ollama.image.projector"
+ }
+
+ // NOTE(review): n is used both as the section length here and as the new
+ // absolute offset below. Those agree only for the first segment: if n is
+ // a length the update should be offset += n, and if n is an end offset
+ // the length should be n-offset. Confirm against llm.DecodeGGML.
+ layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
+ if err != nil {
+ return nil, err
+ }
+
+ layers = append(layers, &layerWithGGML{layer, ggml})
+ offset = n
+ }
+
+ return layers, nil
+}
+
+// detectContentType reads all of r into memory and classifies the bytes.
+//
+// It returns the result of llm.DetectGGMLType when that is non-empty, otherwise
+// the result of http.DetectContentType unless that is the generic
+// "application/octet-stream", and otherwise "unknown". The only error returned
+// is one from reading r.
+func detectContentType(r io.Reader) (string, error) {
+ var b bytes.Buffer
+ if _, err := io.Copy(&b, r); err != nil {
+ return "", err
+ }
+
+ // GGML detection takes precedence over generic MIME sniffing.
+ if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
+ return contentType, nil
+ }
+
+ if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
+ return contentType, nil
+ }
+
+ return "unknown", nil
+}
diff --git a/server/modelpath.go b/server/modelpath.go
index 7d333876..86908226 100644
--- a/server/modelpath.go
+++ b/server/modelpath.go
@@ -6,6 +6,7 @@ import (
"net/url"
"os"
"path/filepath"
+ "regexp"
"strings"
)
@@ -25,9 +26,10 @@ const (
)
var (
- ErrInvalidImageFormat = errors.New("invalid image format")
- ErrInvalidProtocol = errors.New("invalid protocol scheme")
- ErrInsecureProtocol = errors.New("insecure protocol http")
+ ErrInvalidImageFormat = errors.New("invalid image format")
+ ErrInvalidProtocol = errors.New("invalid protocol scheme")
+ ErrInsecureProtocol = errors.New("insecure protocol http")
+ ErrInvalidDigestFormat = errors.New("invalid digest format")
)
func ParseModelPath(name string) ModelPath {
@@ -149,6 +151,17 @@ func GetBlobsPath(digest string) (string, error) {
return "", err
}
+ // only accept actual sha256 digests
+ pattern := "^sha256[:-][0-9a-fA-F]{64}$"
+ re := regexp.MustCompile(pattern)
+ if err != nil {
+ return "", err
+ }
+
+ if digest != "" && !re.MatchString(digest) {
+ return "", ErrInvalidDigestFormat
+ }
+
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(dir, "blobs", digest)
dirPath := filepath.Dir(path)
diff --git a/server/modelpath_test.go b/server/modelpath_test.go
index 8b26d52c..30741d87 100644
--- a/server/modelpath_test.go
+++ b/server/modelpath_test.go
@@ -1,6 +1,73 @@
package server
-import "testing"
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestGetBlobsPath checks that GetBlobsPath maps an empty digest to the blobs
+// directory and well-formed sha256 digests (":" or "-" separated) to a file
+// path beneath $OLLAMA_MODELS/blobs, and that it rejects malformed digests
+// with ErrInvalidDigestFormat.
+func TestGetBlobsPath(t *testing.T) {
+	// GetBlobsPath expects an actual directory to exist
+	dir, err := os.MkdirTemp("", "ollama-test")
+	assert.Nil(t, err)
+	defer os.RemoveAll(dir)
+
+	tests := []struct {
+		name     string
+		digest   string
+		expected string
+		err      error
+	}{
+		{
+			"empty digest",
+			"",
+			filepath.Join(dir, "blobs"),
+			nil,
+		},
+		{
+			"valid with colon",
+			"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
+			filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
+			nil,
+		},
+		{
+			"valid with dash",
+			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
+			filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
+			nil,
+		},
+		{
+			"digest too short",
+			"sha256-45640291",
+			"",
+			ErrInvalidDigestFormat,
+		},
+		{
+			"digest too long",
+			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
+			"",
+			ErrInvalidDigestFormat,
+		},
+		{
+			"digest invalid chars",
+			"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
+			"",
+			ErrInvalidDigestFormat,
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Setenv("OLLAMA_MODELS", dir)
+
+			got, err := GetBlobsPath(tc.digest)
+
+			// ErrorIs takes the actual error first and the expected target
+			// second, so a wrapped sentinel would still match.
+			assert.ErrorIs(t, err, tc.err, tc.name)
+			assert.Equal(t, tc.expected, got, tc.name)
+		})
+	}
+}
func TestParseModelPath(t *testing.T) {
tests := []struct {
diff --git a/server/routes.go b/server/routes.go
index b0d36b14..7dfeb513 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1,6 +1,7 @@
package server
import (
+ "cmp"
"context"
"encoding/json"
"errors"
@@ -15,11 +16,8 @@ import (
"os"
"os/signal"
"path/filepath"
- "reflect"
- "runtime"
"strconv"
"strings"
- "sync"
"syscall"
"time"
@@ -31,14 +29,16 @@ import (
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
- "github.com/ollama/ollama/parser"
+ "github.com/ollama/ollama/server/envconfig"
+ "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
var mode string = gin.DebugMode
type Server struct {
- addr net.Addr
+ addr net.Addr
+ sched *Scheduler
}
func init() {
@@ -53,88 +53,8 @@ func init() {
gin.SetMode(mode)
}
-var loaded struct {
- mu sync.Mutex
-
- llama *llm.LlamaServer
-
- expireTimer *time.Timer
-
- model string
- adapters []string
- projectors []string
- *api.Options
-}
-
var defaultSessionDuration = 5 * time.Minute
-func unload() {
- if loaded.llama != nil {
- loaded.llama.Close()
- }
-
- loaded.llama = nil
- loaded.model = ""
- loaded.adapters = nil
- loaded.projectors = nil
- loaded.Options = nil
-}
-
-// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
- ctx, cancel := context.WithTimeout(c, 10*time.Second)
- defer cancel()
-
- needLoad := loaded.llama == nil || // is there a model loaded?
- loaded.model != model.ModelPath || // has the base model changed?
- !reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
- !reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
- !reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
- loaded.llama.Ping(ctx) != nil
-
- if needLoad {
- if loaded.llama != nil {
- slog.Info("changing loaded model")
- unload()
- }
-
- llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
- if err != nil {
- // some older models are not compatible with newer versions of llama.cpp
- // show a generalized compatibility error until there is a better way to
- // check for model compatibility
- if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
- err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
- }
-
- return err
- }
-
- loaded.model = model.ModelPath
- loaded.adapters = model.AdapterPaths
- loaded.projectors = model.ProjectorPaths
- loaded.llama = llama
- loaded.Options = &opts
-
- if err = llama.WaitUntilRunning(); err != nil {
- slog.Error("error loading llama server", "error", err)
- unload()
- return err
- }
- }
-
- if loaded.expireTimer == nil {
- loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
- loaded.mu.Lock()
- defer loaded.mu.Unlock()
- unload()
- })
- }
-
- loaded.expireTimer.Reset(sessionDuration)
- return nil
-}
-
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -154,9 +74,7 @@ func isSupportedImageType(image []byte) bool {
return slices.Contains(allowedTypes, contentType)
}
-func GenerateHandler(c *gin.Context) {
- loaded.mu.Lock()
- defer loaded.mu.Unlock()
+func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@@ -224,8 +142,12 @@ func GenerateHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
- if err := load(c, model, opts, sessionDuration); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+ var runner *runnerRef
+ select {
+ case runner = <-rCh:
+ case err = <-eCh:
+ handleErrorResponse(c, err)
return
}
@@ -275,7 +197,7 @@ func GenerateHandler(c *gin.Context) {
sb.Reset()
if req.Context != nil {
- prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
+ prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -297,9 +219,6 @@ func GenerateHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
- // Update model expiration
- loaded.expireTimer.Reset(sessionDuration)
-
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -331,7 +250,7 @@ func GenerateHandler(c *gin.Context) {
}
// TODO (jmorganca): encode() should not strip special tokens
- tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
+ tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
@@ -359,7 +278,7 @@ func GenerateHandler(c *gin.Context) {
Images: images,
Options: opts,
}
- if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
+ if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -421,10 +340,7 @@ func getDefaultSessionDuration() time.Duration {
return defaultSessionDuration
}
-func EmbeddingsHandler(c *gin.Context) {
- loaded.mu.Lock()
- defer loaded.mu.Unlock()
-
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -469,8 +385,12 @@ func EmbeddingsHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
- if err := load(c, model, opts, sessionDuration); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+ var runner *runnerRef
+ select {
+ case runner = <-rCh:
+ case err = <-eCh:
+ handleErrorResponse(c, err)
return
}
@@ -480,7 +400,7 @@ func EmbeddingsHandler(c *gin.Context) {
return
}
- embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
+ embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -493,7 +413,7 @@ func EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
-func PullModelHandler(c *gin.Context) {
+func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -542,7 +462,7 @@ func PullModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
-func PushModelHandler(c *gin.Context) {
+func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -591,30 +511,19 @@ func PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
-func CreateModelHandler(c *gin.Context) {
+func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
- err := c.ShouldBindJSON(&req)
- switch {
- case errors.Is(err, io.EOF):
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
- case err != nil:
+ } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- var model string
- if req.Model != "" {
- model = req.Model
- } else if req.Name != "" {
- model = req.Name
- } else {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
- return
- }
-
- if err := ParseModelPath(model).Validate(); err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ name := model.ParseName(cmp.Or(req.Model, req.Name))
+ if !name.IsValid() {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
@@ -623,19 +532,19 @@ func CreateModelHandler(c *gin.Context) {
return
}
- var modelfile io.Reader = strings.NewReader(req.Modelfile)
+ var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
- mf, err := os.Open(req.Path)
+ f, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
- defer mf.Close()
+ defer f.Close()
- modelfile = mf
+ r = f
}
- commands, err := parser.Parse(modelfile)
+ modelfile, err := model.ParseFile(r)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -651,7 +560,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
- if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
+ if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -664,7 +573,7 @@ func CreateModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
-func DeleteModelHandler(c *gin.Context) {
+func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -709,7 +618,7 @@ func DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil)
}
-func ShowModelHandler(c *gin.Context) {
+func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -799,109 +708,115 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
- mf, err := ShowModelfile(model)
- if err != nil {
- return nil, err
- }
-
- resp.Modelfile = mf
+ var sb strings.Builder
+ fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
+ fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
+ fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
+ fmt.Fprint(&sb, model.String())
+ resp.Modelfile = sb.String()
return resp, nil
}
-func ListModelsHandler(c *gin.Context) {
- models := make([]api.ModelResponse, 0)
- manifestsPath, err := GetManifestPath()
+// ListModelsHandler walks the manifests directory and responds with one
+// api.ModelResponse per non-hidden manifest file, most recently modified
+// first. Any error while walking, parsing a manifest, or decoding its config
+// aborts the listing with a 500 response.
+func (s *Server) ListModelsHandler(c *gin.Context) {
+ manifests, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- modelResponse := func(modelName string) (api.ModelResponse, error) {
- model, err := GetModel(modelName)
- if err != nil {
- return api.ModelResponse{}, err
- }
-
- modelDetails := api.ModelDetails{
- Format: model.Config.ModelFormat,
- Family: model.Config.ModelFamily,
- Families: model.Config.ModelFamilies,
- ParameterSize: model.Config.ModelType,
- QuantizationLevel: model.Config.FileType,
- }
-
- return api.ModelResponse{
- Model: model.ShortName,
- Name: model.ShortName,
- Size: model.Size,
- Digest: model.Digest,
- Details: modelDetails,
- }, nil
- }
-
- walkFunc := func(path string, info os.FileInfo, _ error) error {
+ // Start from an empty, non-nil slice: a nil slice is encoded as JSON
+ // null, which would change the response shape when no models exist.
+ models := make([]api.ModelResponse, 0)
+ if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
- path, tag := filepath.Split(path)
- model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
- modelPath := strings.Join([]string{model, tag}, ":")
- canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")
-
- resp, err := modelResponse(canonicalModelPath)
+ rel, err := filepath.Rel(manifests, path)
if err != nil {
- slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
- // nolint: nilerr
+ return err
+ }
+
+ if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
+ return err
+ } else if hidden {
return nil
}
- resp.ModifiedAt = info.ModTime()
- models = append(models, resp)
+ n := model.ParseNameFromFilepath(rel)
+ m, err := ParseNamedManifest(n)
+ if err != nil {
+ return err
+ }
+
+ f, err := m.Config.Open()
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ var c ConfigV2
+ if err := json.NewDecoder(f).Decode(&c); err != nil {
+ return err
+ }
+
+ // tag should never be masked
+ models = append(models, api.ModelResponse{
+ Model: n.DisplayShortest(),
+ Name: n.DisplayShortest(),
+ Size: m.Size(),
+ Digest: m.Digest,
+ ModifiedAt: info.ModTime(),
+ Details: api.ModelDetails{
+ Format: c.ModelFormat,
+ Family: c.ModelFamily,
+ Families: c.ModelFamilies,
+ ParameterSize: c.ModelType,
+ QuantizationLevel: c.FileType,
+ },
+ })
}
return nil
- }
-
- if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
+ }); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
+ slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
+ // most recently modified first
+ return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
+ })
+
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
-func CopyModelHandler(c *gin.Context) {
- var req api.CopyRequest
- err := c.ShouldBindJSON(&req)
- switch {
- case errors.Is(err, io.EOF):
+// CopyModelHandler binds an api.CopyRequest, validates both the source and
+// destination names, and copies the model. It responds 400 for a missing or
+// malformed body or an invalid name, 404 when CopyModel reports
+// os.ErrNotExist, and 500 for any other copy error. Nothing is written
+// explicitly on success.
+func (s *Server) CopyModelHandler(c *gin.Context) {
+ var r api.CopyRequest
+ if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
- case err != nil:
+ } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- if req.Source == "" || req.Destination == "" {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
+ src := model.ParseName(r.Source)
+ if !src.IsValid() {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
return
}
- if err := ParseModelPath(req.Destination).Validate(); err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ dst := model.ParseName(r.Destination)
+ if !dst.IsValid() {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
return
}
- if err := CopyModel(req.Source, req.Destination); err != nil {
- if os.IsNotExist(err) {
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
- } else {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- }
- return
+ // errors.Is also matches wrapped not-exist errors.
+ if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
+ } else if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
-func HeadBlobHandler(c *gin.Context) {
+func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -916,7 +831,7 @@ func HeadBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
}
-func CreateBlobHandler(c *gin.Context) {
+func (s *Server) CreateBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -946,20 +861,9 @@ func CreateBlobHandler(c *gin.Context) {
return
}
- if _, err := layer.Commit(); err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
c.Status(http.StatusCreated)
}
-var defaultAllowOrigins = []string{
- "localhost",
- "127.0.0.1",
- "0.0.0.0",
-}
-
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {
@@ -1031,6 +935,11 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
if allowedHost(host) {
+ if c.Request.Method == "OPTIONS" {
+ c.AbortWithStatus(http.StatusNoContent)
+ return
+ }
+
c.Next()
return
}
@@ -1043,19 +952,8 @@ func (s *Server) GenerateRoutes() http.Handler {
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowBrowserExtensions = true
-
- if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
- config.AllowOrigins = strings.Split(allowedOrigins, ",")
- }
-
- for _, allowOrigin := range defaultAllowOrigins {
- config.AllowOrigins = append(config.AllowOrigins,
- fmt.Sprintf("http://%s", allowOrigin),
- fmt.Sprintf("https://%s", allowOrigin),
- fmt.Sprintf("http://%s:*", allowOrigin),
- fmt.Sprintf("https://%s:*", allowOrigin),
- )
- }
+ config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
+ config.AllowOrigins = envconfig.AllowOrigins
r := gin.Default()
r.Use(
@@ -1063,27 +961,27 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr),
)
- r.POST("/api/pull", PullModelHandler)
- r.POST("/api/generate", GenerateHandler)
- r.POST("/api/chat", ChatHandler)
- r.POST("/api/embeddings", EmbeddingsHandler)
- r.POST("/api/create", CreateModelHandler)
- r.POST("/api/push", PushModelHandler)
- r.POST("/api/copy", CopyModelHandler)
- r.DELETE("/api/delete", DeleteModelHandler)
- r.POST("/api/show", ShowModelHandler)
- r.POST("/api/blobs/:digest", CreateBlobHandler)
- r.HEAD("/api/blobs/:digest", HeadBlobHandler)
+ r.POST("/api/pull", s.PullModelHandler)
+ r.POST("/api/generate", s.GenerateHandler)
+ r.POST("/api/chat", s.ChatHandler)
+ r.POST("/api/embeddings", s.EmbeddingsHandler)
+ r.POST("/api/create", s.CreateModelHandler)
+ r.POST("/api/push", s.PushModelHandler)
+ r.POST("/api/copy", s.CopyModelHandler)
+ r.DELETE("/api/delete", s.DeleteModelHandler)
+ r.POST("/api/show", s.ShowModelHandler)
+ r.POST("/api/blobs/:digest", s.CreateBlobHandler)
+ r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
// Compatibility endpoints
- r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
+ r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
- r.Handle(method, "/api/tags", ListModelsHandler)
+ r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
@@ -1094,10 +992,11 @@ func (s *Server) GenerateRoutes() http.Handler {
func Serve(ln net.Listener) error {
level := slog.LevelInfo
- if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+ if envconfig.Debug {
level = slog.LevelDebug
}
+ slog.Info("server config", "env", envconfig.AsMap())
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
@@ -1121,7 +1020,7 @@ func Serve(ln net.Listener) error {
return err
}
- if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+ if !envconfig.NoPrune {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
@@ -1137,7 +1036,9 @@ func Serve(ln net.Listener) error {
}
}
- s := &Server{addr: ln.Addr()}
+ ctx, done := context.WithCancel(context.Background())
+ sched := InitScheduler(ctx)
+ s := &Server{addr: ln.Addr(), sched: sched}
r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1150,7 +1051,9 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
- unload()
+ srvr.Close()
+ done()
+ sched.unloadAllRunners()
gpu.Cleanup()
os.Exit(0)
}()
@@ -1158,12 +1061,12 @@ func Serve(ln net.Listener) error {
if err := llm.Init(); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err)
}
- if runtime.GOOS == "linux" { // TODO - windows too
- // check compatibility to log warnings
- if _, err := gpu.CheckVRAM(); err != nil {
- slog.Info(err.Error())
- }
- }
+
+ s.sched.Run(ctx)
+
+ // At startup we retrieve GPU information so we can get log messages before loading a model
+ // This will log warnings to the log in case we have problems with detected GPUs
+ _ = gpu.GetGPUInfo()
return srvr.Serve(ln)
}
@@ -1219,9 +1122,9 @@ func streamResponse(c *gin.Context, ch chan any) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
- return loaded.llama.Tokenize(ctx, s)
+ return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1232,10 +1135,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
return prompt, nil
}
-func ChatHandler(c *gin.Context) {
- loaded.mu.Lock()
- defer loaded.mu.Unlock()
-
+func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
@@ -1292,8 +1192,12 @@ func ChatHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
- if err := load(c, model, opts, sessionDuration); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+ var runner *runnerRef
+ select {
+ case runner = <-rCh:
+ case err = <-eCh:
+ handleErrorResponse(c, err)
return
}
@@ -1309,7 +1213,7 @@ func ChatHandler(c *gin.Context) {
}, req.Messages...)
}
- prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+ prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1352,8 +1256,6 @@ func ChatHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
- // Update model expiration
- loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
@@ -1376,7 +1278,7 @@ func ChatHandler(c *gin.Context) {
ch <- resp
}
- if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+ if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
@@ -1416,3 +1318,15 @@ func ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
+
+// handleErrorResponse writes err as a JSON error body, picking the status
+// code from the kind of failure: 499 when the request context was canceled,
+// 503 when the scheduler reported ErrMaxQueue, and 500 for anything else.
+func handleErrorResponse(c *gin.Context, err error) {
+	switch {
+	case errors.Is(err, context.Canceled):
+		c.JSON(499, gin.H{"error": "request canceled"})
+	case errors.Is(err, ErrMaxQueue):
+		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	}
+}
diff --git a/server/routes_test.go b/server/routes_test.go
index 4f907702..896dc27b 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
- "github.com/ollama/ollama/parser"
+ "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) {
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model")
- modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
- commands, err := parser.Parse(modelfile)
+ r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
+ modelfile, err := model.ParseFile(r)
assert.Nil(t, err)
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
- err = CreateModel(context.TODO(), name, "", "", commands, fn)
+ err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
assert.Nil(t, err)
}
@@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) {
Method: http.MethodPost,
Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) {
- f, err := os.CreateTemp(t.TempDir(), "ollama-model")
- assert.Nil(t, err)
- defer f.Close()
+ fname := createTestFile(t, "ollama-model")
stream := false
createReq := api.CreateRequest{
Name: "t-bone",
- Modelfile: fmt.Sprintf("FROM %s", f.Name()),
+ Modelfile: fmt.Sprintf("FROM %s", fname),
Stream: &stream,
}
jsonData, err := json.Marshal(createReq)
@@ -216,28 +214,25 @@ func Test_Routes(t *testing.T) {
httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close)
- workDir, err := os.MkdirTemp("", "ollama-test")
- assert.Nil(t, err)
- defer os.RemoveAll(workDir)
- os.Setenv("OLLAMA_MODELS", workDir)
+ t.Setenv("OLLAMA_MODELS", t.TempDir())
for _, tc := range testCases {
- t.Logf("Running Test: [%s]", tc.Name)
- u := httpSrv.URL + tc.Path
- req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
- assert.Nil(t, err)
+ t.Run(tc.Name, func(t *testing.T) {
+ u := httpSrv.URL + tc.Path
+ req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
+ assert.Nil(t, err)
- if tc.Setup != nil {
- tc.Setup(t, req)
- }
+ if tc.Setup != nil {
+ tc.Setup(t, req)
+ }
- resp, err := httpSrv.Client().Do(req)
- assert.Nil(t, err)
- defer resp.Body.Close()
-
- if tc.Expected != nil {
- tc.Expected(t, resp)
- }
+ resp, err := httpSrv.Client().Do(req)
+ assert.Nil(t, err)
+ defer resp.Body.Close()
+ if tc.Expected != nil {
+ tc.Expected(t, resp)
+ }
+ })
}
}
diff --git a/server/sched.go b/server/sched.go
new file mode 100644
index 00000000..c4a071c1
--- /dev/null
+++ b/server/sched.go
@@ -0,0 +1,553 @@
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "reflect"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/format"
+ "github.com/ollama/ollama/gpu"
+ "github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/server/envconfig"
+ "golang.org/x/exp/slices"
+)
+
+// LlmRequest is one request for a model runner as it travels through the
+// scheduler's pending and finished queues.
+type LlmRequest struct {
+ // ctx is the caller's context; once it is done a finished event is queued
+ // for this request.
+ ctx context.Context //nolint:containedctx
+ // model is the model to run; its ModelPath is the key into the loaded map.
+ model *Model
+ // opts are the options the runner is started with and later compared
+ // against when deciding whether a loaded runner can be reused.
+ opts api.Options
+ // sessionDuration is copied onto the runner that serves this request.
+ sessionDuration time.Duration
+ // successCh receives the runner once it is ready for use.
+ successCh chan *runnerRef
+ // errCh receives the error when the request cannot be satisfied.
+ errCh chan error
+}
+
+// Scheduler hands out model runners to requests and decides when runners are
+// loaded and unloaded.
+type Scheduler struct {
+ // pendingReqCh carries requests waiting for a runner (filled by GetRunner,
+ // drained by processPending).
+ pendingReqCh chan *LlmRequest
+ // finishedReqCh carries requests whose context has completed.
+ finishedReqCh chan *LlmRequest
+ // expiredCh carries runners that should be unloaded.
+ expiredCh chan *runnerRef
+ // unloadedCh is signalled after a runner has been unloaded.
+ unloadedCh chan interface{}
+
+ // loaded maps a model path to its runner; guarded by loadedMu.
+ loaded map[string]*runnerRef
+ loadedMu sync.Mutex
+
+ // The hooks below are set in InitScheduler and can be replaced (the tests
+ // swap newServerFn for a stub).
+ loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
+ newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+ getGpuFn func() gpu.GpuInfoList
+}
+
+// ErrMaxQueue is delivered on a request's error channel when the pending
+// request queue is already full.
+var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
+
+// InitScheduler builds a Scheduler with an empty loaded map, event channels
+// buffered to envconfig.MaxQueuedRequests, and hooks pointing at the real llm
+// and gpu implementations. ctx is accepted but not used here; the scheduler
+// loops are started separately by Run.
+func InitScheduler(ctx context.Context) *Scheduler {
+	queueLen := envconfig.MaxQueuedRequests
+	s := &Scheduler{
+		loaded:        make(map[string]*runnerRef),
+		pendingReqCh:  make(chan *LlmRequest, queueLen),
+		finishedReqCh: make(chan *LlmRequest, queueLen),
+		expiredCh:     make(chan *runnerRef, queueLen),
+		unloadedCh:    make(chan interface{}, queueLen),
+		newServerFn:   llm.NewLlamaServer,
+		getGpuFn:      gpu.GetGPUInfo,
+	}
+	// loadFn needs the scheduler itself, so it is wired up after construction.
+	s.loadFn = s.load
+	return s
+}
+
+// GetRunner queues a request for a runner of model and returns the channels
+// on which the runner or an error will be delivered. If the pending queue is
+// full, ErrMaxQueue is placed on the error channel immediately.
+//
+// The context must be canceled to decrement the ref count and release the
+// runner.
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+	// allocate a large enough kv cache for all parallel requests
+	opts.NumCtx *= envconfig.NumParallel
+
+	runnerCh := make(chan *runnerRef)
+	// Buffered so the non-blocking failure path below never stalls.
+	failCh := make(chan error, 1)
+	queued := &LlmRequest{
+		ctx:             c,
+		model:           model,
+		opts:            opts,
+		sessionDuration: sessionDuration,
+		successCh:       runnerCh,
+		errCh:           failCh,
+	}
+
+	select {
+	case s.pendingReqCh <- queued:
+	default:
+		failCh <- ErrMaxQueue
+	}
+	return runnerCh, failCh
+}
+
+// Run returns immediately after starting the scheduler's two goroutines: one
+// handling pending requests and one handling completed requests, expirations
+// and unloads. Both shut down when ctx is done.
+func (s *Scheduler) Run(ctx context.Context) {
+	slog.Debug("starting llm scheduler")
+	go s.processPending(ctx)
+	go s.processCompleted(ctx)
+}
+
+// processPending is the scheduler loop that serves queued requests one at a
+// time until ctx is done. For each request it either reuses an already
+// loaded runner, loads a new one, or picks a runner to expire and waits for
+// the unload event before retrying the same request.
+func (s *Scheduler) processPending(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ slog.Debug("shutting down scheduler pending loop")
+ return
+ case pending := <-s.pendingReqCh:
+ // Block other requests until we get this pending request running
+
+ if pending.ctx.Err() != nil {
+ slog.Debug("pending request cancelled or timed out, skipping scheduling")
+ continue
+ }
+
+ // Retry loop: each pass either satisfies the request (break) or
+ // expires a runner and tries again after it has been unloaded.
+ for {
+ var runnerToExpire *runnerRef
+ s.loadedMu.Lock()
+ runner := s.loaded[pending.model.ModelPath]
+ loadedCount := len(s.loaded)
+ s.loadedMu.Unlock()
+ if runner != nil {
+ if runner.needsReload(ctx, pending) {
+ runnerToExpire = runner
+ } else {
+ // Runner is usable, return it
+ pending.useLoadedRunner(runner, s.finishedReqCh)
+ break
+ }
+ } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
+ slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
+ runnerToExpire = s.findRunnerToUnload()
+ } else {
+ // Either no models are loaded or below envconfig.MaxRunners
+ // Get a refreshed GPU list
+ gpus := s.getGpuFn()
+
+ // Load model for fitting
+ ggml, err := llm.LoadModel(pending.model.ModelPath)
+ if err != nil {
+ pending.errCh <- err
+ break
+ }
+
+ // If we're CPU only mode, just limit by envconfig.MaxRunners above
+ // TODO handle system memory exhaustion
+ if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
+ slog.Debug("cpu mode with existing models, loading")
+ s.loadFn(pending, ggml, gpus)
+ break
+ }
+
+ // No models loaded. Load the model but prefer the best fit.
+ if loadedCount == 0 {
+ slog.Debug("loading first model", "model", pending.model.ModelPath)
+ g := pickBestFitGPUs(pending, ggml, gpus)
+ if g != nil {
+ gpus = g
+ }
+ s.loadFn(pending, ggml, gpus)
+ break
+ }
+
+ // At least one model is already loaded, so we have to see if the new one fits
+ // Update free memory from currently loaded models
+ s.updateFreeSpace(gpus)
+ gpus = pickBestFitGPUs(pending, ggml, gpus)
+ if gpus != nil {
+ slog.Debug("new model fits with existing models, loading")
+ s.loadFn(pending, ggml, gpus)
+ break
+ }
+ runnerToExpire = s.findRunnerToUnload()
+ }
+
+ if runnerToExpire == nil {
+ // Shouldn't happen
+ slog.Error("runner to expire was nil!")
+ continue
+ }
+ // Trigger an expiration to unload once it's done
+ runnerToExpire.refMu.Lock()
+ slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
+ if runnerToExpire.expireTimer != nil {
+ runnerToExpire.expireTimer.Stop()
+ runnerToExpire.expireTimer = nil
+ }
+ // A zero session duration makes the completed loop expire the
+ // runner as soon as its last in-flight request finishes.
+ runnerToExpire.sessionDuration = 0
+ if runnerToExpire.refCount <= 0 {
+ s.expiredCh <- runnerToExpire
+ }
+ runnerToExpire.refMu.Unlock()
+ // Wait for the unload to happen
+ // Note: at this point we're queueing up all incoming requests, even if they were for
+ // a different model that's loaded and not scheduled to be removed.
+ slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
+ select {
+ case <-ctx.Done():
+ slog.Debug("shutting down scheduler pending loop")
+ return
+ case <-s.unloadedCh:
+ slog.Debug("unload completed", "model", runnerToExpire.model)
+ continue
+ }
+ }
+ case <-s.unloadedCh:
+ // An unload request when there are no pending request can be ignored
+ slog.Debug("ignoring unload event with no pending requests")
+ }
+ }
+}
+
+// processCompleted is the scheduler loop that reacts to finished requests and
+// expired runners until ctx is done. A finished request decrements its
+// runner's ref count and, once the runner is idle, either expires it
+// immediately (zero session duration) or arms/resets its expiration timer.
+// An expired runner with no references is unloaded, removed from the loaded
+// map, and an unloaded event is emitted.
+func (s *Scheduler) processCompleted(ctx context.Context) {
+ // Process completed requests, expired timers, and unloading models
+ for {
+ select {
+ case <-ctx.Done():
+ slog.Debug("shutting down scheduler completed loop")
+ return
+ case finished := <-s.finishedReqCh:
+ s.loadedMu.Lock()
+ runner := s.loaded[finished.model.ModelPath]
+ s.loadedMu.Unlock()
+ if runner == nil {
+ slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
+ continue
+ }
+ runner.refMu.Lock()
+ // NOTE(review): refCount is declared as uint, so decrementing it at
+ // zero would wrap around rather than go negative - confirm callers
+ // can never deliver more finished events than references.
+ runner.refCount--
+ if runner.refCount <= 0 {
+ if runner.sessionDuration <= 0 {
+ slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
+ if runner.expireTimer != nil {
+ runner.expireTimer.Stop()
+ runner.expireTimer = nil
+ }
+ s.expiredCh <- runner
+ } else if runner.expireTimer == nil {
+ slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
+ // The timer callback takes refMu itself before queueing the
+ // expiration.
+ runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
+ slog.Debug("timer expired, expiring to unload", "model", runner.model)
+ runner.refMu.Lock()
+ defer runner.refMu.Unlock()
+ if runner.expireTimer != nil {
+ runner.expireTimer.Stop()
+ runner.expireTimer = nil
+ }
+ s.expiredCh <- runner
+ })
+ } else {
+ slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
+ runner.expireTimer.Reset(runner.sessionDuration)
+ }
+ }
+ slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
+ runner.refMu.Unlock()
+ case runner := <-s.expiredCh:
+ slog.Debug("runner expired event received", "model", runner.model)
+ runner.refMu.Lock()
+ if runner.refCount > 0 {
+ // Shouldn't happen, but safeguard to ensure no leaked runners
+ slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
+ go func(runner *runnerRef) {
+ // We can't unload yet, but want to as soon as the current request completes
+ // So queue up another expired event
+ time.Sleep(10 * time.Millisecond)
+ s.expiredCh <- runner
+ }(runner)
+ runner.refMu.Unlock()
+ continue
+ }
+
+ // loadedMu is taken while refMu is still held, so the unload and the
+ // removal from the loaded map happen together.
+ s.loadedMu.Lock()
+ slog.Debug("got lock to unload", "model", runner.model)
+ runner.unload()
+ delete(s.loaded, runner.model)
+ s.loadedMu.Unlock()
+ slog.Debug("runner released", "model", runner.model)
+ runner.refMu.Unlock()
+ slog.Debug("sending an unloaded event", "model", runner.model)
+ s.unloadedCh <- struct{}{}
+ }
+ }
+}
+
+// useLoadedRunner satisfies the pending request with an already loaded
+// runner. Under the runner's lock it takes a reference, adopts the request's
+// session duration, cancels any armed expiration timer, and sends the runner
+// to the requester. A goroutine then waits for the request context to finish
+// and reports the request on the finished channel.
+func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
+	runner.refMu.Lock()
+	defer runner.refMu.Unlock()
+
+	runner.refCount++
+	runner.sessionDuration = pending.sessionDuration
+	if timer := runner.expireTimer; timer != nil {
+		timer.Stop()
+		runner.expireTimer = nil
+	}
+
+	pending.successCh <- runner
+
+	notifyFinished := func() {
+		<-pending.ctx.Done()
+		slog.Debug("context for request finished")
+		finished <- pending
+	}
+	go notifyFinished()
+}
+
+// load starts a new llama server for req on the given gpus and registers it
+// in the loaded map with an initial reference held for req.
+//
+// If the server cannot be created, the error is sent on req.errCh and nothing
+// is registered. Otherwise a goroutine waits for the server to come up while
+// holding the runner's lock: on success the runner is sent on req.successCh
+// and a finished event is queued once req.ctx is done; on failure the
+// reference is dropped, the error is sent on req.errCh, and the runner is
+// queued for expiration.
+func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
+	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+	if err != nil {
+		// some older models are not compatible with newer versions of llama.cpp
+		// show a generalized compatibility error until there is a better way to
+		// check for model compatibility
+		//
+		// errors.Is takes the error under test first and the target second, so
+		// wrapped ErrUnsupportedFormat errors are recognised as well.
+		if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
+			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
+		}
+		slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
+		req.errCh <- err
+		return
+	}
+
+	runner := &runnerRef{
+		model:           req.model.ModelPath,
+		adapters:        req.model.AdapterPaths,
+		projectors:      req.model.ProjectorPaths,
+		llama:           llama,
+		Options:         &req.opts,
+		sessionDuration: req.sessionDuration,
+		gpus:            gpus,
+		estimatedVRAM:   llama.EstimatedVRAM(),
+		loading:         true,
+		refCount:        1,
+	}
+	// Hold the runner lock until the server is up (released by the goroutine
+	// below) so nobody else uses the runner while it is still loading.
+	runner.refMu.Lock()
+	s.loadedMu.Lock()
+	s.loaded[req.model.ModelPath] = runner
+	slog.Info("loaded runners", "count", len(s.loaded))
+	s.loadedMu.Unlock()
+
+	go func() {
+		defer runner.refMu.Unlock()
+		if err := llama.WaitUntilRunning(req.ctx); err != nil {
+			slog.Error("error loading llama server", "error", err)
+			runner.refCount--
+			req.errCh <- err
+			slog.Debug("triggering expiration for failed load", "model", runner.model)
+			s.expiredCh <- runner
+			return
+		}
+		slog.Debug("finished setting up runner", "model", req.model.ModelPath)
+		runner.loading = false
+		go func() {
+			<-req.ctx.Done()
+			slog.Debug("context for request finished")
+			s.finishedReqCh <- req
+		}()
+		req.successCh <- runner
+	}()
+}
+
+// updateFreeSpace lowers the FreeMemory of the entries in allGpus to account
+// for the VRAM that the currently loaded runners are predicted to use. Each
+// runner's estimate is spread evenly across the GPUs recorded for it. A GPU's
+// free memory is only ever reduced: the smaller of the reported value and
+// (total - predicted) wins, and it is set to zero when the prediction exceeds
+// the GPU's total memory. allGpus is modified in place.
+func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
+	type predKey struct {
+		Library string
+		ID      string
+	}
+	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
+	s.loadedMu.Lock()
+	for _, r := range s.loaded {
+		r.refMu.Lock()
+		switch {
+		case r.llama == nil:
+			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
+		case len(r.gpus) == 0:
+			// Nothing to attribute the usage to; skipping also avoids the
+			// divide-by-zero in the per-GPU split below.
+			slog.Warn("runner has no recorded GPUs, memory prediction may be incorrect")
+		default:
+			// TODO this should be broken down by GPU instead of assuming uniform spread
+			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
+			gpuIDs := make([]string, 0, len(r.gpus))
+			for _, g := range r.gpus {
+				gpuIDs = append(gpuIDs, g.ID)
+			}
+			for _, g := range allGpus {
+				if slices.Contains(gpuIDs, g.ID) {
+					predMap[predKey{g.Library, g.ID}] += estimatedVRAMPerGPU
+				}
+			}
+		}
+		r.refMu.Unlock()
+	}
+	s.loadedMu.Unlock()
+
+	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
+	for i := range allGpus {
+		p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]
+		if !ok {
+			continue
+		}
+		slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
+		if p > allGpus[i].TotalMemory {
+			// Shouldn't happen
+			slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
+			allGpus[i].FreeMemory = 0
+		} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
+			// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
+			// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
+			// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
+			allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
+		}
+		slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
+	}
+}
+
+// runnerRef tracks one loaded llama server together with the reference count
+// and expiration state the scheduler uses to decide when it can be unloaded.
+type runnerRef struct {
+ // refMu guards the fields of this struct.
+ refMu sync.Mutex
+ // refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
+ // NOTE(review): refCount is unsigned, so the "<= 0" checks elsewhere only
+ // ever match zero and a decrement at zero would wrap around.
+ refCount uint // prevent unloading if > 0
+ // unloading bool // set to true when we are trying to unload the runner
+
+ llama llm.LlamaServer
+ loading bool // True only during initial load, then false forever
+ gpus gpu.GpuInfoList // Recorded at time of provisioning
+ // estimatedVRAM is the server's EstimatedVRAM() captured when the runner
+ // was created.
+ estimatedVRAM uint64
+
+ // sessionDuration is how long the runner stays loaded after going idle;
+ // zero or less means it is expired as soon as it is idle.
+ sessionDuration time.Duration
+ // expireTimer is the armed idle timer, or nil when none is armed.
+ expireTimer *time.Timer
+
+ // model is the model path, which is also this runner's key in the
+ // scheduler's loaded map.
+ model string
+ adapters []string
+ projectors []string
+ // Options the runner was started with; nil after unload.
+ *api.Options
+}
+
+// unload stops any armed expiration timer, closes the llama server if there
+// is one, and clears the runner's server, adapters, projectors, options and
+// GPU list.
+//
+// The refMu must already be held when calling unload.
+func (runner *runnerRef) unload() {
+	if timer := runner.expireTimer; timer != nil {
+		timer.Stop()
+		runner.expireTimer = nil
+	}
+	if server := runner.llama; server != nil {
+		server.Close()
+	}
+	runner.llama = nil
+	runner.adapters, runner.projectors = nil, nil
+	runner.Options = nil
+	runner.gpus = nil
+}
+
+// needsReload reports whether the loaded runner cannot serve req as-is. That
+// is the case when the runner has no options (it was unloaded), when the
+// adapters, projectors or runner options differ from the request, or when
+// the server does not answer a ping within the timeout (10s, or 2m while the
+// runner is still doing its initial load). A negative NumGPU in the request
+// is treated as matching whatever the runner was started with.
+func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
+	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
+	runner.refMu.Lock()
+	defer runner.refMu.Unlock()
+
+	if runner.Options == nil {
+		return true
+	}
+
+	pingTimeout := 10 * time.Second
+	if runner.loading {
+		pingTimeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
+	}
+
+	// Don't reload runner if num_gpu=-1 was provided
+	current, requested := runner.Options.Runner, req.opts.Runner
+	if requested.NumGPU < 0 {
+		current.NumGPU = -1
+		requested.NumGPU = -1
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, pingTimeout)
+	defer cancel()
+
+	switch {
+	case !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths): // have the adapters changed?
+		return true
+	case !reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths): // have the projectors changed?
+		return true
+	case !reflect.DeepEqual(current, requested): // have the runner options changed?
+		return true
+	}
+	return runner.llama.Ping(ctx) != nil
+}
+
+// ByDuration implements sort.Interface, ordering runners by ascending
+// session duration with negative durations sorting last.
+type ByDuration []*runnerRef
+
+func (a ByDuration) Len() int {
+	return len(a)
+}
+
+func (a ByDuration) Swap(i, j int) {
+	a[j], a[i] = a[i], a[j]
+}
+
+func (a ByDuration) Less(i, j int) bool {
+	// uint64 to turn negative time (never unload) to largest
+	left := uint64(a[i].sessionDuration)
+	right := uint64(a[j].sessionDuration)
+	return left < right
+}
+
+// TODO - future consideration to pick runners based on size
+// type BySize []*runnerRef
+// func (a BySize) Len() int { return len(a) }
+// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
+
+// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
+// If the model can not be fit fully within the available GPU(s) nil is returned
+func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
+ var estimatedVRAM uint64
+ for _, gl := range gpus.ByLibrary() {
+ var ok bool
+ sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
+
+ // TODO - potentially sort by performance capability, existing models loaded, etc.
+ // Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
+ sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
+
+ // First attempt to fit the model into a single GPU
+ for _, g := range sgl {
+ if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+ slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+ return []gpu.GpuInfo{g}
+ }
+ }
+
+ // TODO future refinements
+ // - if multiple Libraries, see if any single GPU in any Library will fit
+ // - try subsets of GPUs instead of just falling back to 1 or all in a family
+
+ // Now try all the GPUs
+ if ok, estimatedVRAM = llm.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+ slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
+ return gl
+ }
+ }
+ return nil
+}
+
+// findRunnerToUnload finds a runner to unload to make room for a new model.
+// Runners are considered in order of ascending session duration; the first
+// one with no references is returned. If every runner is in use, the one
+// with the shortest session duration is returned. It returns nil when no
+// runners are loaded at all.
+func (s *Scheduler) findRunnerToUnload() *runnerRef {
+	s.loadedMu.Lock()
+	runnerList := make([]*runnerRef, 0, len(s.loaded))
+	for _, r := range s.loaded {
+		runnerList = append(runnerList, r)
+	}
+	s.loadedMu.Unlock()
+
+	// The loaded map can be emptied by the completed loop between the
+	// caller's check and this snapshot; report "nothing to unload" instead of
+	// indexing into an empty list below.
+	if len(runnerList) == 0 {
+		slog.Debug("no loaded runners to unload")
+		return nil
+	}
+
+	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
+	// e.g., if we have multiple options, will one make room for the request?
+	sort.Sort(ByDuration(runnerList))
+
+	// First try to find a runner that's already idle
+	for _, runner := range runnerList {
+		runner.refMu.Lock()
+		rc := runner.refCount
+		runner.refMu.Unlock()
+		if rc == 0 {
+			slog.Debug("found an idle runner to unload")
+			return runner
+		}
+	}
+	// None appear idle, just wait for the one with the shortest duration
+	slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
+	return runnerList[0]
+}
+
+// unloadAllRunners closes the llama server of every loaded runner while
+// holding the loaded-map lock. The runners stay in the map.
+func (s *Scheduler) unloadAllRunners() {
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+
+	for name, r := range s.loaded {
+		if r.llama == nil {
+			continue
+		}
+		slog.Debug("shutting down runner", "model", name)
+		r.llama.Close()
+	}
+}
diff --git a/server/sched_test.go b/server/sched_test.go
new file mode 100644
index 00000000..7e4faa61
--- /dev/null
+++ b/server/sched_test.go
@@ -0,0 +1,601 @@
+package server
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "fmt"
+ "log/slog"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/app/lifecycle"
+ "github.com/ollama/ollama/format"
+ "github.com/ollama/ollama/gpu"
+ "github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/server/envconfig"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// init sets OLLAMA_DEBUG=1 and then initialises logging for every test in this
+// package. The environment variable is set first, presumably so that
+// lifecycle.InitLogging picks it up (confirm against lifecycle).
+func init() {
+ os.Setenv("OLLAMA_DEBUG", "1")
+ lifecycle.InitLogging()
+}
+
+// TestInitScheduler checks that InitScheduler returns a scheduler whose loaded
+// map has been allocated.
+func TestInitScheduler(t *testing.T) {
+ ctx, done := context.WithCancel(context.Background())
+ defer done()
+ s := InitScheduler(ctx)
+ // s.loaded is read under loadedMu, matching how the scheduler accesses it.
+ s.loadedMu.Lock()
+ require.NotNil(t, s.loaded)
+ s.loadedMu.Unlock()
+}
+
+// TestLoad drives Scheduler.load directly through three cases, using a stubbed
+// newServerFn:
+//  1. server creation fails: an error is delivered and nothing is loaded;
+//  2. server creation succeeds: the runner is delivered with refCount 1;
+//  3. the server is created but WaitUntilRunning fails: an error is delivered,
+//     the runner stays in s.loaded with refCount 0 and is queued on expiredCh.
+func TestLoad(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer done()
+ s := InitScheduler(ctx)
+ var ggml *llm.GGML // value not used in tests
+ req := &LlmRequest{
+ ctx: ctx,
+ model: &Model{ModelPath: "foo"},
+ opts: api.DefaultOptions(),
+ successCh: make(chan *runnerRef, 1),
+ errCh: make(chan error, 1),
+ sessionDuration: 2,
+ }
+ // Fail to load model first
+ s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+ return nil, fmt.Errorf("something failed to load model blah")
+ }
+ gpus := gpu.GpuInfoList{}
+ s.load(req, ggml, gpus)
+ require.Len(t, req.successCh, 0)
+ require.Len(t, req.errCh, 1)
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 0)
+ s.loadedMu.Unlock()
+ err := <-req.errCh
+ require.Contains(t, err.Error(), "this model may be incompatible")
+
+ // Successful load: the stub now hands back a mock server.
+ server := &mockLlm{estimatedVRAM: 10}
+ s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+ return server, nil
+ }
+ s.load(req, ggml, gpus)
+ select {
+ case err := <-req.errCh:
+ require.NoError(t, err)
+ case resp := <-req.successCh:
+ require.Equal(t, uint64(10), resp.estimatedVRAM)
+ require.Equal(t, uint(1), resp.refCount)
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 1)
+ s.loadedMu.Unlock()
+ }
+
+ // Same mock server, but WaitUntilRunning now fails; a different model path
+ // is used so the runner is stored under a new key.
+ req.model.ModelPath = "dummy_model_path"
+ server.waitResp = fmt.Errorf("wait failure")
+ s.load(req, ggml, gpus)
+ select {
+ case err := <-req.errCh:
+ require.Contains(t, err.Error(), "wait failure")
+ case resp := <-req.successCh:
+ t.Errorf("unexpected success %v", resp)
+ }
+ s.loadedMu.Lock()
+ runner := s.loaded["dummy_model_path"]
+ s.loadedMu.Unlock()
+ require.NotNil(t, runner)
+ require.Equal(t, uint(0), runner.refCount)
+ // Brief pause before checking that the failed runner was queued for expiry.
+ time.Sleep(1 * time.Millisecond)
+ require.Len(t, s.expiredCh, 1)
+}
+
+// bundle groups everything one scheduler test scenario needs: a cancellable
+// context, the mock server handed out by newServer, the request to submit and
+// the parsed model file.
+type bundle struct {
+ ctx context.Context //nolint:containedctx
+ ctxDone func() // cancels ctx; tests call it to mark the scenario's request as finished
+ srv *mockLlm
+ req *LlmRequest
+ ggml *llm.GGML
+}
+
+// newServer has the signature the tests assign to Scheduler.newServerFn. It
+// ignores every argument and always returns the bundle's mock server with a
+// nil error.
+func (b *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	return b.srv, nil
+}
+
+// newScenario builds a self-contained scheduler test fixture for modelName.
+//
+// It derives a cancellable context from ctx, writes a minimal GGUF v3 file (a
+// handful of llama KV entries and one tensor) into t.TempDir(), loads it back
+// with llm.LoadModel, and prepares an LlmRequest with buffered result channels
+// and a 5ms session duration. The returned bundle's mock server reports
+// estimatedVRAM as its VRAM estimate.
+//
+// The test is stopped immediately if the temp file cannot be created or the
+// model cannot be loaded.
+func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
+	// Mark as helper first so every failure below is attributed to the caller.
+	t.Helper()
+
+	scenario := &bundle{}
+	scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
+
+	f, err := os.CreateTemp(t.TempDir(), modelName)
+	// Must stop here on failure: f would be nil and the calls below would panic.
+	require.NoError(t, err)
+	defer f.Close()
+
+	gguf := llm.NewGGUFV3(binary.LittleEndian)
+	err = gguf.Encode(f, llm.KV{
+		"general.architecture":          "llama",
+		"general.name":                  "name",
+		"llama.context_length":          uint32(32),
+		"llama.embedding_length":        uint32(4096),
+		"llama.block_count":             uint32(1),
+		"llama.attention.head_count":    uint32(32),
+		"llama.attention.head_count_kv": uint32(32),
+		"tokenizer.ggml.tokens":         []string{" "},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, []llm.Tensor{
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+	})
+	// Non-fatal: the LoadModel check below is a hard stop.
+	assert.NoError(t, err)
+
+	fname := f.Name()
+	model := &Model{Name: modelName, ModelPath: fname}
+	scenario.ggml, err = llm.LoadModel(model.ModelPath)
+	require.NoError(t, err)
+
+	scenario.req = &LlmRequest{
+		ctx:             scenario.ctx,
+		model:           model,
+		opts:            api.DefaultOptions(),
+		sessionDuration: 5 * time.Millisecond,
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+	}
+	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
+	return scenario
+}
+
+// TestRequests pushes a sequence of requests through a running scheduler by
+// writing to s.pendingReqCh, with a stubbed single "metal" GPU (24GB total,
+// 12GB free). In order it covers: reuse of an already loaded runner, a reload
+// triggered by changed adapter paths, envconfig.MaxRunners limiting the number
+// of loaded models, a CPU-only load (NumGPU = 0), and a request that has to
+// wait for an earlier runner to be unloaded.
+//
+// The test relies on short sleeps and on the 500ms context timeout, and it
+// mutates the package-level envconfig.MaxRunners.
+func TestRequests(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+ defer done()
+
+ // Same model, same request
+ scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
+ scenario1a.req.sessionDuration = 0
+ scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
+ scenario1b.req.model = scenario1a.req.model
+ scenario1b.ggml = scenario1a.ggml
+ scenario1b.req.sessionDuration = 0
+
+ // simple reload of same model
+ scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
+ scenario2a.req.model = scenario1a.req.model
+ scenario2a.ggml = scenario1a.ggml
+
+ // Multiple loaded models
+ scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
+ scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
+ scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
+ scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
+ scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
+
+ s := InitScheduler(ctx)
+ // Stub GPU discovery so the test does not depend on real hardware.
+ s.getGpuFn = func() gpu.GpuInfoList {
+ g := gpu.GpuInfo{Library: "metal"}
+ g.TotalMemory = 24 * format.GigaByte
+ g.FreeMemory = 12 * format.GigaByte
+ return []gpu.GpuInfo{g}
+ }
+ s.newServerFn = scenario1a.newServer
+ slog.Info("scenario1a")
+ // The request is queued before Run is called, so it is still pending here.
+ s.pendingReqCh <- scenario1a.req
+ require.Len(t, s.pendingReqCh, 1)
+ s.Run(ctx)
+ select {
+ case resp := <-scenario1a.req.successCh:
+ require.Equal(t, resp.llama, scenario1a.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario1a.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+
+ // Same runner as first request due to not needing a reload
+ s.newServerFn = scenario1b.newServer
+ slog.Info("scenario1b")
+ s.pendingReqCh <- scenario1b.req
+ select {
+ case resp := <-scenario1b.req.successCh:
+ require.Equal(t, resp.llama, scenario1a.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario1b.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+
+ // Trigger a reload
+ s.newServerFn = scenario2a.newServer
+ scenario2a.req.model.AdapterPaths = []string{"new"}
+ slog.Info("scenario2a")
+ s.pendingReqCh <- scenario2a.req
+ // finish first two requests, so model can reload
+ time.Sleep(1 * time.Millisecond)
+ scenario1a.ctxDone()
+ scenario1b.ctxDone()
+ select {
+ case resp := <-scenario2a.req.successCh:
+ require.Equal(t, resp.llama, scenario2a.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario2a.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+
+ // With MaxRunners = 1 only one model is expected to stay loaded.
+ envconfig.MaxRunners = 1
+ s.newServerFn = scenario3a.newServer
+ slog.Info("scenario3a")
+ s.pendingReqCh <- scenario3a.req
+ // finish prior request, so new model can load
+ time.Sleep(1 * time.Millisecond)
+ scenario2a.ctxDone()
+ select {
+ case resp := <-scenario3a.req.successCh:
+ require.Equal(t, resp.llama, scenario3a.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario3a.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 1)
+ s.loadedMu.Unlock()
+
+ // MaxRunners = 0: a second model is expected to load next to the first one.
+ envconfig.MaxRunners = 0
+ s.newServerFn = scenario3b.newServer
+ slog.Info("scenario3b")
+ s.pendingReqCh <- scenario3b.req
+ select {
+ case resp := <-scenario3b.req.successCh:
+ require.Equal(t, resp.llama, scenario3b.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario3b.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 2)
+ s.loadedMu.Unlock()
+
+ // This is a CPU load with NumGPU = 0 so it should load
+ s.newServerFn = scenario3c.newServer
+ slog.Info("scenario3c")
+ s.pendingReqCh <- scenario3c.req
+ select {
+ case resp := <-scenario3c.req.successCh:
+ require.Equal(t, resp.llama, scenario3c.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario3c.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 3)
+ s.loadedMu.Unlock()
+
+ // Try to load a model that wont fit
+ s.newServerFn = scenario3d.newServer
+ slog.Info("scenario3d")
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 3)
+ s.loadedMu.Unlock()
+ scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
+ time.Sleep(2 * time.Millisecond)
+ s.pendingReqCh <- scenario3d.req
+ // finish prior request, so new model can load
+ time.Sleep(6 * time.Millisecond)
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 2)
+ s.loadedMu.Unlock()
+ scenario3b.ctxDone()
+ select {
+ case resp := <-scenario3d.req.successCh:
+ require.Equal(t, resp.llama, scenario3d.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, scenario3d.req.errCh, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 2)
+ s.loadedMu.Unlock()
+}
+
+// TestGetRunner exercises the public GetRunner entry point with
+// envconfig.MaxQueuedRequests set to 1: the first request is queued and later
+// served, a second request issued while the queue is full is rejected with a
+// "server busy" error, and a request for a model with an invalid path yields
+// an error mentioning that path.
+func TestGetRunner(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer done()
+
+ // Three independent models, all with a zero session duration
+ scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+ scenario1a.req.sessionDuration = 0
+ scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
+ scenario1b.req.sessionDuration = 0
+ scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
+ scenario1c.req.sessionDuration = 0
+ envconfig.MaxQueuedRequests = 1
+ s := InitScheduler(ctx)
+ s.getGpuFn = func() gpu.GpuInfoList {
+ g := gpu.GpuInfo{Library: "metal"}
+ g.TotalMemory = 24 * format.GigaByte
+ g.FreeMemory = 12 * format.GigaByte
+ return []gpu.GpuInfo{g}
+ }
+ s.newServerFn = scenario1a.newServer
+ slog.Info("scenario1a")
+ // The scheduler is not running yet, so this request stays in the queue.
+ successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+ require.Len(t, s.pendingReqCh, 1)
+ slog.Info("scenario1b")
+ // The queue is already full, so this request is rejected immediately.
+ successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
+ require.Len(t, s.pendingReqCh, 1)
+ require.Len(t, successCh1b, 0)
+ require.Len(t, errCh1b, 1)
+ err := <-errCh1b
+ require.Contains(t, err.Error(), "server busy")
+ s.Run(ctx)
+ select {
+ case resp := <-successCh1a:
+ require.Equal(t, resp.llama, scenario1a.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, errCh1a, 0)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ scenario1a.ctxDone()
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 1)
+ s.loadedMu.Unlock()
+
+ scenario1c.req.model.ModelPath = "bad path"
+ slog.Info("scenario1c")
+ successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
+ // Starts in pending channel, then should be quickly processed to return an error
+ time.Sleep(5 * time.Millisecond)
+ require.Len(t, successCh1c, 0)
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 0)
+ s.loadedMu.Unlock()
+ require.Len(t, errCh1c, 1)
+ err = <-errCh1c
+ require.Contains(t, err.Error(), "bad path")
+ scenario1b.ctxDone()
+}
+
+// TestPrematureExpired checks that the scheduler tolerates events arriving out
+// of order: an "expired" event for a runner that is still referenced, and a
+// "finished" event for a request whose runner has already been unloaded.
+//
+// TODO - add one scenario that triggers the bogus finished event with positive ref count
+func TestPrematureExpired(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+ defer done()
+
+ // A single scenario is used for the whole test
+ scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+ s := InitScheduler(ctx)
+ s.getGpuFn = func() gpu.GpuInfoList {
+ g := gpu.GpuInfo{Library: "metal"}
+ g.TotalMemory = 24 * format.GigaByte
+ g.FreeMemory = 12 * format.GigaByte
+ return []gpu.GpuInfo{g}
+ }
+ s.newServerFn = scenario1a.newServer
+ successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+ require.Len(t, s.pendingReqCh, 1)
+ s.Run(ctx)
+ select {
+ case resp := <-successCh1a:
+ require.Equal(t, resp.llama, scenario1a.srv)
+ require.Len(t, s.pendingReqCh, 0)
+ require.Len(t, errCh1a, 0)
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 1)
+ s.loadedMu.Unlock()
+ slog.Info("sending premature expired event now")
+ s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ // Let the session duration elapse, finish the request, then give the
+ // scheduler time to drain the finished event and unload the runner.
+ time.Sleep(scenario1a.req.sessionDuration)
+ scenario1a.ctxDone()
+ time.Sleep(20 * time.Millisecond)
+ require.LessOrEqual(t, len(s.finishedReqCh), 1)
+ time.Sleep(10 * time.Millisecond)
+ require.Len(t, s.finishedReqCh, 0)
+ s.loadedMu.Lock()
+ require.Len(t, s.loaded, 0)
+ s.loadedMu.Unlock()
+
+ // also shouldn't happen in real life
+ s.finishedReqCh <- scenario1a.req
+ time.Sleep(5 * time.Millisecond)
+}
+
+// TestUseLoadedRunner checks LlmRequest.useLoadedRunner in isolation: the
+// runner's refCount becomes 1, its sessionDuration is replaced by the
+// request's, the runner is delivered on successCh, and once the request
+// context is cancelled the request is sent on the finished channel.
+func TestUseLoadedRunner(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ req := &LlmRequest{
+ ctx: ctx,
+ opts: api.DefaultOptions(),
+ successCh: make(chan *runnerRef, 1),
+ sessionDuration: 2,
+ }
+ finished := make(chan *LlmRequest)
+ llm1 := &mockLlm{}
+ r1 := &runnerRef{llama: llm1, sessionDuration: 1}
+ req.useLoadedRunner(r1, finished)
+ require.Equal(t, uint(1), r1.refCount)
+ require.Equal(t, time.Duration(2), r1.sessionDuration)
+ select {
+ case success := <-req.successCh:
+ require.Equal(t, r1, success)
+ case <-ctx.Done():
+ t.Errorf("timeout")
+ }
+ // Cancelling the context is what signals that the request has finished.
+ done()
+ fin := <-finished
+ require.Equal(t, req, fin)
+}
+
+// TestUpdateFreeSpace checks that updateFreeSpace rewrites FreeMemory on the
+// passed GPUs to account for the VRAM estimates of the loaded runners. Two
+// runners (estimates 100 and 200) share the same two GPUs, and each GPU is
+// expected to end up with its TotalMemory minus 150.
+func TestUpdateFreeSpace(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer done()
+ gpus := gpu.GpuInfoList{
+ {
+ Library: "a",
+ ID: "1",
+ },
+ {
+ Library: "a",
+ ID: "2",
+ },
+ }
+ gpus[0].TotalMemory = 1000
+ gpus[0].FreeMemory = 900
+ gpus[1].TotalMemory = 2000
+ gpus[1].FreeMemory = 1900
+ llm1 := &mockLlm{estimatedVRAM: 100}
+ llm2 := &mockLlm{estimatedVRAM: 200}
+ r1 := &runnerRef{llama: llm1, gpus: gpus}
+ r2 := &runnerRef{llama: llm2, gpus: gpus}
+
+ s := InitScheduler(ctx)
+ s.loadedMu.Lock()
+ s.loaded["a"] = r1
+ s.loaded["b"] = r2
+ s.loadedMu.Unlock()
+
+ s.updateFreeSpace(gpus)
+ // 1000-150 and 2000-150; the initial FreeMemory values (900/1900) are not
+ // reflected in the expected results.
+ require.Equal(t, uint64(850), gpus[0].FreeMemory)
+ require.Equal(t, uint64(1850), gpus[1].FreeMemory)
+}
+
+// TestFindRunnerToUnload checks both selection rules of findRunnerToUnload: an
+// idle runner (refCount 0) is preferred, and when every runner is in use the
+// one with the smaller sessionDuration is returned.
+func TestFindRunnerToUnload(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer done()
+
+ r1 := &runnerRef{refCount: 1, sessionDuration: 1}
+ r2 := &runnerRef{sessionDuration: 2}
+
+ s := InitScheduler(ctx)
+ s.loadedMu.Lock()
+ s.loaded["a"] = r1
+ s.loaded["b"] = r2
+ s.loadedMu.Unlock()
+
+ // r2 is idle, so it wins even though its duration is longer.
+ resp := s.findRunnerToUnload()
+ require.Equal(t, r2, resp)
+ // With both runners busy, r1 (sessionDuration 1) is expected.
+ r2.refCount = 1
+ resp = s.findRunnerToUnload()
+ require.Equal(t, r1, resp)
+
+}
+
+// TestNeedsReload walks runnerRef.needsReload through a series of request
+// changes and asserts the result after each one: differing adapters,
+// projectors or NumBatch and a failing Ping all report true; once everything
+// matches it reports false; NumGPU = 99 reports true while NumGPU = -1 reports
+// false.
+func TestNeedsReload(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer done()
+
+ llm := &mockLlm{}
+ do := api.DefaultOptions()
+ runner := &runnerRef{
+ adapters: []string{"adapter1"},
+ projectors: []string{"projector1"},
+ Options: &do,
+ llama: llm,
+ }
+ req := &LlmRequest{
+ model: &Model{
+ AdapterPaths: []string{"adapter2"},
+ ProjectorPaths: []string{"projector2"},
+ },
+ opts: api.DefaultOptions(),
+ }
+ // Adapters and projectors both differ.
+ resp := runner.needsReload(ctx, req)
+ require.True(t, resp)
+ // Adapters match, projectors still differ.
+ req.model.AdapterPaths = runner.adapters
+ resp = runner.needsReload(ctx, req)
+ require.True(t, resp)
+ // Paths match, but NumBatch differs from the runner's options.
+ req.model.ProjectorPaths = runner.projectors
+ runner.loading = true
+ req.opts.NumBatch = 1234
+ resp = runner.needsReload(ctx, req)
+ require.True(t, resp)
+ // Options match, but the mock server's Ping returns an error.
+ req.opts.NumBatch = runner.Options.NumBatch
+ llm.pingResp = fmt.Errorf("foo")
+ resp = runner.needsReload(ctx, req)
+ require.True(t, resp)
+ // Everything matches and Ping succeeds.
+ llm.pingResp = nil
+ resp = runner.needsReload(ctx, req)
+ require.False(t, resp)
+ req.opts.NumGPU = 99
+ resp = runner.needsReload(ctx, req)
+ require.True(t, resp)
+ req.opts.NumGPU = -1
+ resp = runner.needsReload(ctx, req)
+ require.False(t, resp)
+}
+
+// TestUnloadAllRunners checks that unloadAllRunners is safe to call with
+// nothing loaded and that it calls Close on the server of every loaded runner.
+func TestUnloadAllRunners(t *testing.T) {
+ ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer done()
+
+ llm1 := &mockLlm{}
+ llm2 := &mockLlm{}
+ s := InitScheduler(ctx)
+ // First call runs against a scheduler with no runners registered.
+ s.unloadAllRunners()
+
+ r1 := &runnerRef{llama: llm1}
+ r2 := &runnerRef{llama: llm2}
+
+ s.loadedMu.Lock()
+ s.loaded["a"] = r1
+ s.loaded["b"] = r2
+ s.loadedMu.Unlock()
+ s.unloadAllRunners()
+
+ require.True(t, llm1.closeCalled)
+ require.True(t, llm2.closeCalled)
+}
+
+// TestUnload checks runnerRef.unload: it closes the runner's server when one is
+// set, and it works on a runner without a server, leaving adapters nil.
+func TestUnload(t *testing.T) {
+ llm1 := &mockLlm{}
+ r1 := &runnerRef{llama: llm1}
+ r2 := &runnerRef{adapters: []string{"A"}}
+ r1.unload()
+ require.True(t, llm1.closeCalled)
+ r2.unload()
+ require.Nil(t, r2.adapters)
+}
+
+// mockLlm is a test double that the scheduler tests return wherever an
+// llm.LlamaServer is expected. Every method simply returns the matching
+// canned field; Close additionally records that it was called.
+type mockLlm struct {
+	pingResp          error
+	waitResp          error
+	completionResp    error
+	embeddingResp     []float64
+	embeddingRespErr  error
+	tokenizeResp      []int
+	tokenizeRespErr   error
+	detokenizeResp    string
+	detokenizeRespErr error
+	closeResp         error
+	closeCalled       bool // set to true by Close
+	estimatedVRAM     uint64
+}
+
+// Ping returns the configured pingResp.
+func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
+
+// WaitUntilRunning returns the configured waitResp.
+func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
+
+// Completion never invokes fn; it only returns the configured completionResp.
+func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+	return s.completionResp
+}
+
+// Embedding returns the configured embeddingResp and embeddingRespErr.
+func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	return s.embeddingResp, s.embeddingRespErr
+}
+
+// Tokenize returns the configured tokenizeResp and tokenizeRespErr.
+func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
+	return s.tokenizeResp, s.tokenizeRespErr
+}
+
+// Detokenize returns the configured detokenizeResp and detokenizeRespErr.
+func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
+	return s.detokenizeResp, s.detokenizeRespErr
+}
+
+// Close records the call in closeCalled and returns the configured closeResp.
+func (s *mockLlm) Close() error {
+	s.closeCalled = true
+	return s.closeResp
+}
+
+// EstimatedVRAM returns the configured estimatedVRAM.
+func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
diff --git a/types/errtypes/errtypes.go b/types/errtypes/errtypes.go
new file mode 100644
index 00000000..e3a18d0b
--- /dev/null
+++ b/types/errtypes/errtypes.go
@@ -0,0 +1,18 @@
+// Package errtypes contains custom error types
+package errtypes
+
+import (
+ "fmt"
+ "strings"
+)
+
+// UnknownOllamaKeyErrMsg is the fixed message fragment embedded in every
+// UnknownOllamaKey error string.
+const UnknownOllamaKeyErrMsg = "unknown ollama key"
+
+// UnknownOllamaKey is the error type carrying a key that was not recognised.
+//
+// TODO: This should have a structured response from the API
+type UnknownOllamaKey struct {
+	Key string
+}
+
+// Error implements the error interface. The key is trimmed of surrounding
+// whitespace and rendered as a quoted string.
+func (e *UnknownOllamaKey) Error() string {
+	key := strings.TrimSpace(e.Key)
+	return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, key)
+}
diff --git a/types/model/digest.go b/types/model/digest.go
deleted file mode 100644
index d5a7a155..00000000
--- a/types/model/digest.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package model
-
-import (
- "log/slog"
- "strings"
- "unicode"
-)
-
-// Digest represents a digest of a model Manifest. It is a comparable value
-// type and is immutable.
-//
-// The zero Digest is not a valid digest.
-type Digest struct {
- s string
-}
-
-// Type returns the digest type of the digest.
-//
-// Example:
-//
-// ParseDigest("sha256-1234").Type() // returns "sha256"
-func (d Digest) Type() string {
- typ, _, _ := strings.Cut(d.s, "-")
- return typ
-}
-
-// String returns the digest in the form of "-", or the
-// empty string if the digest is invalid.
-func (d Digest) String() string { return d.s }
-
-// IsValid returns true if the digest is valid (not zero).
-//
-// A valid digest may be created only by ParseDigest, or
-// ParseName(name).Digest().
-func (d Digest) IsValid() bool { return d.s != "" }
-
-// LogValue implements slog.Value.
-func (d Digest) LogValue() slog.Value {
- return slog.StringValue(d.String())
-}
-
-var (
- _ slog.LogValuer = Digest{}
-)
-
-// ParseDigest parses a string in the form of "-" into a
-// Digest.
-func ParseDigest(s string) Digest {
- typ, digest, ok := strings.Cut(s, "-")
- if ok && isValidDigestType(typ) && isValidHex(digest) {
- return Digest{s: s}
- }
- return Digest{}
-}
-
-func isValidDigestType(s string) bool {
- if len(s) == 0 {
- return false
- }
- for _, r := range s {
- if !unicode.IsLower(r) && !unicode.IsDigit(r) {
- return false
- }
- }
- return true
-}
-
-func isValidHex(s string) bool {
- if len(s) == 0 {
- return false
- }
- for i := range s {
- c := s[i]
- if c < '0' || c > '9' && c < 'a' || c > 'f' {
- return false
- }
- }
- return true
-}
diff --git a/types/model/digest_test.go b/types/model/digest_test.go
deleted file mode 100644
index 5096a28a..00000000
--- a/types/model/digest_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package model
-
-import "testing"
-
-var testDigests = map[string]Digest{
- "": {},
- "sha256-1234": {s: "sha256-1234"},
- "sha256-5678": {s: "sha256-5678"},
- "blake2-9abc": {s: "blake2-9abc"},
- "-1234": {},
- "sha256-": {},
- "sha256-1234-5678": {},
- "sha256-P": {}, // invalid hex
- "sha256-1234P": {},
- "---": {},
-}
-
-func TestDigestParse(t *testing.T) {
- // Test cases.
- for s, want := range testDigests {
- got := ParseDigest(s)
- t.Logf("ParseDigest(%q) = %#v", s, got)
- if got != want {
- t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
- }
- }
-}
-
-func TestDigestString(t *testing.T) {
- // Test cases.
- for s, d := range testDigests {
- want := s
- if !d.IsValid() {
- want = ""
- }
- got := d.String()
- if got != want {
- t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
- }
-
- got = ParseDigest(s).String()
- if got != want {
- t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
- }
- }
-}
diff --git a/types/model/file.go b/types/model/file.go
new file mode 100644
index 00000000..ee398309
--- /dev/null
+++ b/types/model/file.go
@@ -0,0 +1,299 @@
+package model
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+)
+
+// File is a parsed Modelfile: the commands it contains, in the order they
+// were appended by ParseFile.
+type File struct {
+	Commands []Command
+}
+
+// String renders the file back to text, one command per line, with each line
+// terminated by a newline.
+func (f File) String() string {
+	var out strings.Builder
+	for i := range f.Commands {
+		out.WriteString(f.Commands[i].String())
+		out.WriteByte('\n')
+	}
+	return out.String()
+}
+
+// Command is one parsed Modelfile instruction. Name is the normalised command
+// name ("model" for FROM, the lower-cased keyword for most other commands, or
+// the parameter name for PARAMETER) and Args is its value.
+type Command struct {
+	Name string
+	Args string
+}
+
+// String renders the command in Modelfile syntax:
+//   - "model" becomes a FROM line with the argument written as-is;
+//   - license, template, system and adapter use the upper-cased name followed
+//     by the quoted argument;
+//   - "message" splits Args at the first ": " into role and text;
+//   - any other name is written as a PARAMETER line.
+func (c Command) String() string {
+	switch c.Name {
+	case "model":
+		return fmt.Sprintf("FROM %s", c.Args)
+	case "license", "template", "system", "adapter":
+		keyword := strings.ToUpper(c.Name)
+		return fmt.Sprintf("%s %s", keyword, quote(c.Args))
+	case "message":
+		role, text, _ := strings.Cut(c.Args, ": ")
+		return fmt.Sprintf("MESSAGE %s %s", role, quote(text))
+	default:
+		return fmt.Sprintf("PARAMETER %s %s", c.Name, quote(c.Args))
+	}
+}
+
+// state identifies what the Modelfile parser is currently reading.
+type state int
+
+const (
+ // stateNil: outside any command; whitespace is skipped, '#' starts a
+ // comment and any other rune starts a command name.
+ stateNil state = iota
+ // stateName: reading the command keyword.
+ stateName
+ // stateValue: reading a command's value.
+ stateValue
+ // stateParameter: reading the parameter name that follows PARAMETER.
+ stateParameter
+ // stateMessage: reading the role that follows MESSAGE.
+ stateMessage
+ // stateComment: inside a '#' comment, which runs to the end of the line.
+ stateComment
+)
+
+// Sentinel errors returned by ParseFile.
+var (
+ // errMissingFrom: the input parsed cleanly but contains no FROM command.
+ errMissingFrom = errors.New("no FROM line")
+ // errInvalidMessageRole: a MESSAGE command used a role other than the three listed.
+ errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
+ // errInvalidCommand: a line started with something that is not a known command.
+ errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+)
+
+// ParseFile reads a Modelfile from r one rune at a time and returns the
+// commands it contains, in input order.
+//
+// Each rune is passed to parseRuneForState together with the current state;
+// whenever the state changes, the text buffered so far is interpreted
+// according to the state being left (command name, parameter name, message
+// role or value).
+//
+// Errors:
+//   - any non-EOF read error from r is returned unchanged;
+//   - errInvalidCommand for an unknown command keyword;
+//   - errInvalidMessageRole for an unknown MESSAGE role;
+//   - io.ErrUnexpectedEOF (possibly wrapped) when the input ends, or a name
+//     is broken off, before a command is complete, including an unterminated
+//     quoted value;
+//   - errMissingFrom when no FROM command was found.
+func ParseFile(r io.Reader) (*File, error) {
+ var cmd Command
+ var curr state
+ var b bytes.Buffer
+ var role string
+
+ var f File
+
+ br := bufio.NewReader(r)
+ for {
+ // NOTE: this r (the rune just read) shadows the io.Reader parameter
+ // for the rest of the loop body.
+ r, _, err := br.ReadRune()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ // next is the proposed new state; r becomes 0 when the rune is
+ // consumed as a delimiter and must not be buffered.
+ next, r, err := parseRuneForState(r, curr)
+ if errors.Is(err, io.ErrUnexpectedEOF) {
+ return nil, fmt.Errorf("%w: %s", err, b.String())
+ } else if err != nil {
+ return nil, err
+ }
+
+ // process the state transition, some transitions need to be intercepted and redirected
+ if next != curr {
+ switch curr {
+ case stateName:
+ if !isValidCommand(b.String()) {
+ return nil, errInvalidCommand
+ }
+
+ // next state sometimes depends on the current buffer value
+ switch s := strings.ToLower(b.String()); s {
+ case "from":
+ cmd.Name = "model"
+ case "parameter":
+ // transition to stateParameter which sets command name
+ next = stateParameter
+ case "message":
+ // transition to stateMessage which validates the message role
+ next = stateMessage
+ fallthrough
+ default:
+ cmd.Name = s
+ }
+ case stateParameter:
+ cmd.Name = b.String()
+ case stateMessage:
+ if !isValidMessageRole(b.String()) {
+ return nil, errInvalidMessageRole
+ }
+
+ role = b.String()
+ case stateComment, stateNil:
+ // pass
+ case stateValue:
+ // The value is not finished if its quotes are still open, or if
+ // the rune is an in-line space: keep the rune in the buffer and
+ // stay in stateValue.
+ s, ok := unquote(b.String())
+ if !ok || isSpace(r) {
+ if _, err := b.WriteRune(r); err != nil {
+ return nil, err
+ }
+
+ continue
+ }
+
+ // MESSAGE values are stored as "<role>: <text>".
+ if role != "" {
+ s = role + ": " + s
+ role = ""
+ }
+
+ cmd.Args = s
+ f.Commands = append(f.Commands, cmd)
+ }
+
+ b.Reset()
+ curr = next
+ }
+
+ // Only printable runes are buffered here; this also drops the 0 rune
+ // returned for consumed delimiters.
+ if strconv.IsPrint(r) {
+ if _, err := b.WriteRune(r); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ // flush the buffer
+ switch curr {
+ case stateComment, stateNil:
+ // pass; nothing to flush
+ case stateValue:
+ // Input ended in the middle of a value: accept it only if any quoting
+ // is properly closed.
+ s, ok := unquote(b.String())
+ if !ok {
+ return nil, io.ErrUnexpectedEOF
+ }
+
+ if role != "" {
+ s = role + ": " + s
+ }
+
+ cmd.Args = s
+ f.Commands = append(f.Commands, cmd)
+ default:
+ // Input ended while still reading a command name, parameter name or role.
+ return nil, io.ErrUnexpectedEOF
+ }
+
+ // A Modelfile is only valid if it contains at least one FROM command.
+ for _, cmd := range f.Commands {
+ if cmd.Name == "model" {
+ return &f, nil
+ }
+ }
+
+ return nil, errMissingFrom
+}
+
+// parseRuneForState is the transition function of the Modelfile parser. Given
+// the rune r just read and the current state cs, it returns the state to move
+// to, the rune the caller should keep (0 when r is consumed as a delimiter or
+// ignored) and an error when r is not allowed in cs.
+//
+// Errors:
+//   - errInvalidCommand when a command keyword is followed by anything other
+//     than a letter or a space/tab;
+//   - io.ErrUnexpectedEOF when a parameter name or message role contains a
+//     rune outside its allowed set;
+//   - a descriptive error when cs is not one of the declared states.
+func parseRuneForState(r rune, cs state) (state, rune, error) {
+	switch cs {
+	case stateNil:
+		switch {
+		case r == '#':
+			return stateComment, 0, nil
+		case isSpace(r), isNewline(r):
+			return stateNil, 0, nil
+		default:
+			return stateName, r, nil
+		}
+	case stateName:
+		switch {
+		case isAlpha(r):
+			return stateName, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, errInvalidCommand
+		}
+	case stateValue:
+		// Any whitespace is only a candidate terminator. The rune is handed
+		// back so ParseFile can keep it when the value continues (an in-line
+		// space, or a newline inside an open quote).
+		if isNewline(r) || isSpace(r) {
+			return stateNil, r, nil
+		}
+		return stateValue, r, nil
+	case stateParameter:
+		switch {
+		case isAlpha(r), isNumber(r), r == '_':
+			return stateParameter, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, io.ErrUnexpectedEOF
+		}
+	case stateMessage:
+		switch {
+		case isAlpha(r):
+			return stateMessage, r, nil
+		case isSpace(r):
+			return stateValue, 0, nil
+		default:
+			return stateNil, 0, io.ErrUnexpectedEOF
+		}
+	case stateComment:
+		// A comment runs to the end of the line; its content is discarded.
+		if isNewline(r) {
+			return stateNil, 0, nil
+		}
+		return stateComment, 0, nil
+	default:
+		// Only reachable with a state value that is not declared above; give
+		// the error a message so it can be diagnosed.
+		return stateNil, 0, fmt.Errorf("unexpected parser state %d", cs)
+	}
+}
+
+// quote wraps s in quotes when it could not be written bare on a single
+// Modelfile line: that is, when it contains a newline or starts or ends with
+// a space. Triple double quotes are used if s itself contains a double quote,
+// plain double quotes otherwise. All other strings are returned unchanged.
+func quote(s string) string {
+	needsQuoting := strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ")
+	if !needsQuoting {
+		return s
+	}
+
+	delim := `"`
+	if strings.Contains(s, "\"") {
+		delim = `"""`
+	}
+
+	return delim + s + delim
+}
+
+// unquote strips one layer of quoting from s. A string starting with three
+// double quotes must also end with three (and be at least six bytes long); a
+// string starting with a single double quote must also end with one. In both
+// cases the inner text is returned. If the opening quote has no matching
+// close, it returns "" and false. Strings that do not start with a double
+// quote are returned unchanged with true.
+func unquote(s string) (string, bool) {
+	// TODO: single quotes
+	const triple = `"""`
+
+	switch {
+	case strings.HasPrefix(s, triple):
+		if len(s) < 2*len(triple) || !strings.HasSuffix(s, triple) {
+			return "", false
+		}
+		return s[len(triple) : len(s)-len(triple)], true
+	case strings.HasPrefix(s, `"`):
+		if len(s) < 2 || !strings.HasSuffix(s, `"`) {
+			return "", false
+		}
+		return s[1 : len(s)-1], true
+	default:
+		return s, true
+	}
+}
+
+// isAlpha reports whether r is an ASCII letter (a-z or A-Z).
+func isAlpha(r rune) bool {
+	switch {
+	case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z':
+		return true
+	}
+	return false
+}
+
+// isNumber reports whether r is an ASCII decimal digit (0-9).
+func isNumber(r rune) bool {
+	return '0' <= r && r <= '9'
+}
+
+// isSpace reports whether r is a space or a tab. Newlines are not spaces; see
+// isNewline.
+func isSpace(r rune) bool {
+	switch r {
+	case ' ', '\t':
+		return true
+	}
+	return false
+}
+
+// isNewline reports whether r is a carriage return or a line feed.
+func isNewline(r rune) bool {
+	switch r {
+	case '\r', '\n':
+		return true
+	}
+	return false
+}
+
+// isValidMessageRole reports whether role is exactly "system", "user" or
+// "assistant". The comparison is case-sensitive.
+func isValidMessageRole(role string) bool {
+	switch role {
+	case "system", "user", "assistant":
+		return true
+	}
+	return false
+}
+
+// isValidCommand reports whether cmd, compared after lower-casing, is one of
+// the Modelfile command keywords.
+func isValidCommand(cmd string) bool {
+	lowered := strings.ToLower(cmd)
+	for _, known := range []string{"from", "license", "template", "system", "adapter", "parameter", "message"} {
+		if lowered == known {
+			return true
+		}
+	}
+	return false
+}
diff --git a/types/model/file_test.go b/types/model/file_test.go
new file mode 100644
index 00000000..8e71760c
--- /dev/null
+++ b/types/model/file_test.go
@@ -0,0 +1,511 @@
+package model
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestParseFileFile parses a small Modelfile using one of each of FROM,
+// ADAPTER, LICENSE and TEMPLATE plus two PARAMETER lines, and checks the
+// resulting commands and their order.
+func TestParseFileFile(t *testing.T) {
+ input := `
+FROM model1
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+`
+
+ reader := strings.NewReader(input)
+
+ modelfile, err := ParseFile(reader)
+ assert.NoError(t, err)
+
+ // FROM is stored under the name "model"; PARAMETER lines are stored under
+ // the parameter's own name.
+ expectedCommands := []Command{
+ {Name: "model", Args: "model1"},
+ {Name: "adapter", Args: "adapter1"},
+ {Name: "license", Args: "MIT"},
+ {Name: "param1", Args: "value1"},
+ {Name: "param2", Args: "value2"},
+ {Name: "template", Args: "template1"},
+ }
+
+ assert.Equal(t, expectedCommands, modelfile.Commands)
+}
+
+// TestParseFileFrom is a table test for the FROM command: several argument
+// shapes (bare name, paths, name:tag), the errMissingFrom error for input
+// without FROM, and a FROM line that is not the first command.
+func TestParseFileFrom(t *testing.T) {
+ var cases = []struct {
+ input string
+ expected []Command
+ err error
+ }{
+ {
+ "FROM foo",
+ []Command{{Name: "model", Args: "foo"}},
+ nil,
+ },
+ {
+ "FROM /path/to/model",
+ []Command{{Name: "model", Args: "/path/to/model"}},
+ nil,
+ },
+ {
+ "FROM /path/to/model/fp16.bin",
+ []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
+ nil,
+ },
+ {
+ "FROM llama3:latest",
+ []Command{{Name: "model", Args: "llama3:latest"}},
+ nil,
+ },
+ {
+ "FROM llama3:7b-instruct-q4_K_M",
+ []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
+ nil,
+ },
+ {
+ "", nil, errMissingFrom,
+ },
+ {
+ "PARAMETER param1 value1",
+ nil,
+ errMissingFrom,
+ },
+ {
+ "PARAMETER param1 value1\nFROM foo",
+ []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
+ nil,
+ },
+ }
+
+ // Subtests are unnamed. Commands are only compared when ParseFile returned
+ // a non-nil file.
+ for _, c := range cases {
+ t.Run("", func(t *testing.T) {
+ modelfile, err := ParseFile(strings.NewReader(c.input))
+ assert.ErrorIs(t, err, c.err)
+ if modelfile != nil {
+ assert.Equal(t, c.expected, modelfile.Commands)
+ }
+ })
+ }
+}
+
+// TestParseFileParametersMissingValue checks that a PARAMETER line with a name
+// but no value makes ParseFile fail with io.ErrUnexpectedEOF.
+func TestParseFileParametersMissingValue(t *testing.T) {
+ input := `
+FROM foo
+PARAMETER param1
+`
+
+ reader := strings.NewReader(input)
+
+ _, err := ParseFile(reader)
+ assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
+}
+
+// TestParseFileBadCommand checks that an unknown command keyword makes
+// ParseFile fail with errInvalidCommand.
+func TestParseFileBadCommand(t *testing.T) {
+ input := `
+FROM foo
+BADCOMMAND param1 value1
+`
+ _, err := ParseFile(strings.NewReader(input))
+ assert.ErrorIs(t, err, errInvalidCommand)
+
+}
+
+// TestParseFileMessages is a table test for the MESSAGE command. It covers
+// single and multiple messages, input with and without a trailing newline, a
+// triple-quoted multi-line message, an invalid role (errInvalidMessageRole)
+// and a role without any text (io.ErrUnexpectedEOF). Parsed messages are
+// expected in the form "<role>: <text>".
+func TestParseFileMessages(t *testing.T) {
+ var cases = []struct {
+ input string
+ expected []Command
+ err error
+ }{
+ {
+ `
+FROM foo
+MESSAGE system You are a file parser. Always parse things.
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "message", Args: "system: You are a file parser. Always parse things."},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+MESSAGE system You are a file parser. Always parse things.`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "message", Args: "system: You are a file parser. Always parse things."},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+MESSAGE system You are a file parser. Always parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "message", Args: "system: You are a file parser. Always parse things."},
+ {Name: "message", Args: "user: Hey there!"},
+ {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+MESSAGE system """
+You are a multiline file parser. Always parse things.
+"""
+ `,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+MESSAGE badguy I'm a bad guy!
+`,
+ nil,
+ errInvalidMessageRole,
+ },
+ {
+ `
+FROM foo
+MESSAGE system
+`,
+ nil,
+ io.ErrUnexpectedEOF,
+ },
+ {
+ `
+FROM foo
+MESSAGE system`,
+ nil,
+ io.ErrUnexpectedEOF,
+ },
+ }
+
+ // Subtests are unnamed. Commands are only compared when ParseFile returned
+ // a non-nil file.
+ for _, c := range cases {
+ t.Run("", func(t *testing.T) {
+ modelfile, err := ParseFile(strings.NewReader(c.input))
+ assert.ErrorIs(t, err, c.err)
+ if modelfile != nil {
+ assert.Equal(t, c.expected, modelfile.Commands)
+ }
+ })
+ }
+}
+
+// TestParseFileQuoted is a table test for quoted values. It covers
+// triple-quoted values spanning several lines or one line, unterminated quotes
+// (io.ErrUnexpectedEOF), empty quoted values, and quote characters nested
+// inside a quoted value.
+func TestParseFileQuoted(t *testing.T) {
+ var cases = []struct {
+ multiline string
+ expected []Command
+ err error
+ }{
+ {
+ `
+FROM foo
+SYSTEM """
+This is a
+multiline system.
+"""
+ `,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: "\nThis is a\nmultiline system.\n"},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM """
+This is a
+multiline system."""
+ `,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: "\nThis is a\nmultiline system."},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM """This is a
+multiline system."""
+ `,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: "This is a\nmultiline system."},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM """This is a multiline system."""
+ `,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: "This is a multiline system."},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM """This is a multiline system.""
+ `,
+ nil,
+ io.ErrUnexpectedEOF,
+ },
+ {
+ `
+FROM foo
+SYSTEM "
+ `,
+ nil,
+ io.ErrUnexpectedEOF,
+ },
+ {
+ `
+FROM foo
+SYSTEM """
+This is a multiline system with "quotes".
+"""
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM """"""
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: ""},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM ""
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: ""},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM "'"
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: "'"},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+SYSTEM """''"'""'""'"'''''""'""'"""
+`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "system", Args: `''"'""'""'"'''''""'""'`},
+ },
+ nil,
+ },
+ {
+ `
+FROM foo
+TEMPLATE """
+{{ .Prompt }}
+"""`,
+ []Command{
+ {Name: "model", Args: "foo"},
+ {Name: "template", Args: "\n{{ .Prompt }}\n"},
+ },
+ nil,
+ },
+ }
+
+ // Subtests are unnamed. Commands are only compared when ParseFile returned
+ // a non-nil file.
+ for _, c := range cases {
+ t.Run("", func(t *testing.T) {
+ modelfile, err := ParseFile(strings.NewReader(c.multiline))
+ assert.ErrorIs(t, err, c.err)
+ if modelfile != nil {
+ assert.Equal(t, c.expected, modelfile.Commands)
+ }
+ })
+ }
+}
+
+func TestParseFileParameters(t *testing.T) {
+ var cases = map[string]struct {
+ name, value string
+ }{
+ "numa true": {"numa", "true"},
+ "num_ctx 1": {"num_ctx", "1"},
+ "num_batch 1": {"num_batch", "1"},
+ "num_gqa 1": {"num_gqa", "1"},
+ "num_gpu 1": {"num_gpu", "1"},
+ "main_gpu 1": {"main_gpu", "1"},
+ "low_vram true": {"low_vram", "true"},
+ "f16_kv true": {"f16_kv", "true"},
+ "logits_all true": {"logits_all", "true"},
+ "vocab_only true": {"vocab_only", "true"},
+ "use_mmap true": {"use_mmap", "true"},
+ "use_mlock true": {"use_mlock", "true"},
+ "num_thread 1": {"num_thread", "1"},
+ "num_keep 1": {"num_keep", "1"},
+ "seed 1": {"seed", "1"},
+ "num_predict 1": {"num_predict", "1"},
+ "top_k 1": {"top_k", "1"},
+ "top_p 1.0": {"top_p", "1.0"},
+ "tfs_z 1.0": {"tfs_z", "1.0"},
+ "typical_p 1.0": {"typical_p", "1.0"},
+ "repeat_last_n 1": {"repeat_last_n", "1"},
+ "temperature 1.0": {"temperature", "1.0"},
+ "repeat_penalty 1.0": {"repeat_penalty", "1.0"},
+ "presence_penalty 1.0": {"presence_penalty", "1.0"},
+ "frequency_penalty 1.0": {"frequency_penalty", "1.0"},
+ "mirostat 1": {"mirostat", "1"},
+ "mirostat_tau 1.0": {"mirostat_tau", "1.0"},
+ "mirostat_eta 1.0": {"mirostat_eta", "1.0"},
+ "penalize_newline true": {"penalize_newline", "true"},
+ "stop ### User:": {"stop", "### User:"},
+ "stop ### User: ": {"stop", "### User: "},
+ "stop \"### User:\"": {"stop", "### User:"},
+ "stop \"### User: \"": {"stop", "### User: "},
+ "stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
+ "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
+ "stop <|endoftext|>": {"stop", "<|endoftext|>"},
+ "stop <|eot_id|>": {"stop", "<|eot_id|>"},
+ "stop ": {"stop", ""},
+ }
+
+ for k, v := range cases {
+ t.Run(k, func(t *testing.T) {
+ var b bytes.Buffer
+ fmt.Fprintln(&b, "FROM foo")
+ fmt.Fprintln(&b, "PARAMETER", k)
+ modelfile, err := ParseFile(&b)
+ assert.NoError(t, err)
+
+ assert.Equal(t, []Command{
+ {Name: "model", Args: "foo"},
+ {Name: v.name, Args: v.value},
+ }, modelfile.Commands)
+ })
+ }
+}
+
+func TestParseFileComments(t *testing.T) {
+ var cases = []struct {
+ input string
+ expected []Command
+ }{
+ {
+ `
+# comment
+FROM foo
+ `,
+ []Command{
+ {Name: "model", Args: "foo"},
+ },
+ },
+ }
+
+ for _, c := range cases {
+ t.Run("", func(t *testing.T) {
+ modelfile, err := ParseFile(strings.NewReader(c.input))
+ assert.NoError(t, err)
+ assert.Equal(t, c.expected, modelfile.Commands)
+ })
+ }
+}
+
+func TestParseFileFormatParseFile(t *testing.T) {
+ var cases = []string{
+ `
+FROM foo
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system You are a file parser. Always parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+ `
+FROM foo
+ADAPTER adapter1
+LICENSE MIT
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system """
+You are a store greeter. Always responsed with "Hello!".
+"""
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+ `
+FROM foo
+ADAPTER adapter1
+LICENSE """
+Very long and boring legal text.
+Blah blah blah.
+"Oh look, a quote!"
+"""
+
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE template1
+MESSAGE system """
+You are a store greeter. Always responsed with "Hello!".
+"""
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`,
+ `
+FROM foo
+SYSTEM ""
+`,
+ }
+
+ for _, c := range cases {
+ t.Run("", func(t *testing.T) {
+ modelfile, err := ParseFile(strings.NewReader(c))
+ assert.NoError(t, err)
+
+ modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
+ assert.NoError(t, err)
+
+ assert.Equal(t, modelfile, modelfile2)
+ })
+ }
+
+}
diff --git a/types/model/name.go b/types/model/name.go
index 9c56c49a..b79374c3 100644
--- a/types/model/name.go
+++ b/types/model/name.go
@@ -1,693 +1,425 @@
+// Package model contains types and utilities for parsing, validating, and
+// working with model names and digests.
package model
import (
"cmp"
+ "encoding/hex"
"errors"
"fmt"
- "hash/maphash"
- "io"
"log/slog"
"path/filepath"
- "slices"
"strings"
- "sync"
-
- "github.com/ollama/ollama/types/structs"
)
// Errors
var (
- // ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
- // used by this package, but are exported so that other packages can
- // use them, instead of defining their own errors for them.
- ErrInvalidName = errors.New("invalid model name")
- ErrIncompleteName = errors.New("incomplete model name")
- ErrInvalidDigest = errors.New("invalid digest")
+ // ErrUnqualifiedName represents an error where a name is not fully
+ // qualified. It is not used directly in this package, but is here
+ // to avoid other packages inventing their own error type.
+ // Additionally, it can be conveniently used via [Unqualified].
+ ErrUnqualifiedName = errors.New("unqualified name")
)
-// Defaults
-const (
- // MaskDefault is the default mask used by [Name.DisplayShortest].
- MaskDefault = "registry.ollama.ai/library/?:latest"
-
- // MaskNothing is a mask that masks nothing.
- MaskNothing = "?/?/?:?"
-
- // DefaultFill is the default fill used by [ParseName].
- FillDefault = "registry.ollama.ai/library/?:latest+Q4_0"
-
- // FillNothing is a fill that fills nothing.
- FillNothing = "?/?/?:?+?"
-)
-
-const MaxNamePartLen = 128
-
-type PartKind int
-
-// Levels of concreteness
-const (
- // Each value aligns with its index in the Name.parts array.
-
- PartHost PartKind = iota
- PartNamespace
- PartModel
- PartTag
- PartBuild
- PartDigest
-
- // NumParts is the number of parts in a Name. In this list, it must
- // follow the final part.
- NumParts
-
- PartExtraneous = -1
-)
-
-var kindNames = map[PartKind]string{
- PartHost: "Host",
- PartNamespace: "Namespace",
- PartModel: "Name",
- PartTag: "Tag",
- PartBuild: "Build",
- PartDigest: "Digest",
+// Unqualified is a helper function that returns an error with
+// ErrUnqualifiedName as the cause and the name as the message.
+func Unqualified(n Name) error {
+ return fmt.Errorf("%w: %s", ErrUnqualifiedName, n)
}
-func (k PartKind) String() string {
- return cmp.Or(kindNames[k], "Unknown")
+// MissingPart is used to indicate any part of a name that was "promised" by
+// the presence of a separator, but is missing.
+//
+// The value was chosen because it is deemed unlikely to be set by a user,
+// not a valid part name when checked by [Name.IsValid], and easy to
+// spot in logs.
+const MissingPart = "!MISSING!"
+
+const (
+ defaultHost = "registry.ollama.ai"
+ defaultNamespace = "library"
+ defaultTag = "latest"
+)
+
+// DefaultName returns a name with the default values for the host, namespace,
+// and tag parts. The model and digest parts are empty.
+//
+// - The default host is ("registry.ollama.ai")
+// - The default namespace is ("library")
+// - The default tag is ("latest")
+func DefaultName() Name {
+ return Name{
+ Host: defaultHost,
+ Namespace: defaultNamespace,
+ Tag: defaultTag,
+ }
}
-// Name is an opaque reference to a model. It holds the parts of a model
-// with the case preserved, but is not directly comparable with other Names
-// since model names can be represented with different casing depending on
-// the use case. For instance, "Mistral" and "mistral" are the same model
-// but each version may have come from different sources (e.g. copied from a
-// Web page, or from a file path).
+type partKind int
+
+const (
+ kindHost partKind = iota
+ kindNamespace
+ kindModel
+ kindTag
+ kindDigest
+)
+
+func (k partKind) String() string {
+ switch k {
+ case kindHost:
+ return "host"
+ case kindNamespace:
+ return "namespace"
+ case kindModel:
+ return "model"
+ case kindTag:
+ return "tag"
+ case kindDigest:
+ return "digest"
+ default:
+ return "unknown"
+ }
+}
+
+// Name is a structured representation of a model name string, as defined by
+// [ParseNameBare].
//
-// Valid Names can ONLY be constructed by calling [ParseName].
-//
-// A Name is valid if and only if is have a valid Model part. The other parts
-// are optional.
-//
-// A Name is considered "complete" if it has all parts present. To check if a
-// Name is complete, use [Name.IsComplete].
-//
-// To compare two names in a case-insensitive manner, use [Name.EqualFold].
-//
-// The parts of a Name are:
-//
-// - Host: the domain of the model (optional)
-// - Namespace: the namespace of the model (optional)
-// - Model: the name of the model (required)
-// - Tag: the tag of the model (optional)
-// - Build: the build of the model; usually the quantization or "file type" (optional)
-//
-// The parts can be obtained in their original form by calling [Name.Parts].
-//
-// To check if a Name has at minimum a valid model part, use [Name.IsValid].
+// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
+// is valid.
type Name struct {
- _ structs.Incomparable
- parts [NumParts]string // host, namespace, model, tag, build, digest
-
- // TODO(bmizerany): track offsets and hold s (raw string) here? We
- // could pack the offsets all into a single uint64 since the first
- // parts take less bits since their max offset is less than the max
- // offset of the next part. This would save a ton of bytes per Name
- // and mean zero allocations for String.
+ Host string
+ Namespace string
+ Model string
+ Tag string
+ RawDigest string
}
-// ParseName parses s into a Name, and returns the result of filling it with
-// defaults. The input string must be a valid string
-// representation of a model name in the form:
+// ParseName parses and assembles a Name from a name string. The
+// format of a valid name string is:
//
-// [host/][namespace/][:tag][+build][@-]
+// s:
+// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
+// { host } "/" { namespace } "/" { model } ":" { tag }
+// { host } "/" { namespace } "/" { model } "@" { digest }
+// { host } "/" { namespace } "/" { model }
+// { namespace } "/" { model } ":" { tag } "@" { digest }
+// { namespace } "/" { model } ":" { tag }
+// { namespace } "/" { model } "@" { digest }
+// { namespace } "/" { model }
+// { model } ":" { tag } "@" { digest }
+// { model } ":" { tag }
+// { model } "@" { digest }
+// { model }
+// "@" { digest }
+// host:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
+// length: [1, 350]
+// namespace:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
+// length: [1, 80]
+// model:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
+// length: [1, 80]
+// tag:
+// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
+// length: [1, 80]
+// digest:
+// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
+// length: [1, 80]
//
-// The name part is required, all others are optional. If a part is missing,
-// it is left empty in the returned Name. If a part is invalid, the zero Ref
-// value is returned.
+// Missing host, namespace, and tag parts are filled in from [DefaultName].
+// Use [ParseNameBare] instead to parse without applying those defaults.
//
-// The build part is normalized to uppercase.
-//
-// Examples of valid paths:
-//
-// "example.com/library/mistral:7b+x"
-// "example.com/eva/mistral:7b+Q4_0"
-// "mistral:7b+x"
-// "example.com/mike/mistral:latest+Q4_0"
-// "example.com/bruce/mistral:latest"
-// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
-//
-// Examples of invalid paths:
-//
-// "example.com/mistral:7b+"
-// "example.com/mistral:7b+Q4_0+"
-// "x/y/z/z:8n+I"
-// ""
-//
-// It returns the zero value if any part is invalid.
-//
-// # Fills
-//
-// For any valid s, the fill string is used to fill in missing parts of the
-// Name. The fill string must be a valid Name with the exception that any part
-// may be the string ("?"), which will not be considered for filling.
-func ParseName(s, fill string) Name {
- var r Name
- parts(s)(func(kind PartKind, part string) bool {
- if kind == PartDigest && !ParseDigest(part).IsValid() {
- r = Name{}
- return false
- }
- if kind == PartExtraneous || !isValidPart(kind, part) {
- r = Name{}
- return false
- }
- r.parts[kind] = part
- return true
- })
- if r.IsValid() || r.IsResolved() {
- return fillName(r, fill)
+// The name returned is not guaranteed to be valid. If it is not valid, the
+// field values are left in an undefined state. Use [Name.IsValid] to check
+// if the name is valid.
+func ParseName(s string) Name {
+ return Merge(ParseNameBare(s), DefaultName())
+}
+
+// ParseNameBare parses s as a name string and returns a Name. No merge with
+// [DefaultName] is performed.
+func ParseNameBare(s string) Name {
+ var n Name
+ var promised bool
+
+ s, n.RawDigest, promised = cutLast(s, "@")
+ if promised && n.RawDigest == "" {
+ n.RawDigest = MissingPart
}
- return Name{}
-}
-func parseMask(s string) Name {
- var r Name
- parts(s)(func(kind PartKind, part string) bool {
- if part == "?" {
- // mask part; treat as empty but valid
- return true
- }
- if !isValidPart(kind, part) {
- panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
- }
- r.parts[kind] = part
- return true
- })
- return r
-}
-
-func MustParseName(s, fill string) Name {
- r := ParseName(s, fill)
- if !r.IsValid() {
- panic("invalid Name: " + s)
+ // "/" is an illegal tag character, so we can use it to split the host
+ if strings.LastIndex(s, ":") > strings.LastIndex(s, "/") {
+ s, n.Tag, _ = cutPromised(s, ":")
}
- return r
-}
-// fillName fills in the missing parts of dst with the parts of src.
-//
-// The returned Name will only be valid if dst is valid.
-//
-// It skipps fill parts that are "?".
-func fillName(r Name, fill string) Name {
- fill = cmp.Or(fill, FillDefault)
- f := parseMask(fill)
- if fill != FillNothing && f.IsZero() {
- panic("invalid fill")
+ s, n.Model, promised = cutPromised(s, "/")
+ if !promised {
+ n.Model = s
+ return n
}
- for i := range r.parts {
- if f.parts[i] == "?" {
- continue
- }
- r.parts[i] = cmp.Or(r.parts[i], f.parts[i])
+
+ s, n.Namespace, promised = cutPromised(s, "/")
+ if !promised {
+ n.Namespace = s
+ return n
}
- return r
-}
-// WithBuild returns a copy of r with the build set to the given string.
-func (r Name) WithBuild(build string) Name {
- r.parts[PartBuild] = build
- return r
-}
-
-func (r Name) WithDigest(digest Digest) Name {
- r.parts[PartDigest] = digest.String()
- return r
-}
-
-var mapHashSeed = maphash.MakeSeed()
-
-// MapHash returns a case insensitive hash for use in maps and equality
-// checks. For a convenient way to compare names, use [Name.EqualFold].
-//
-//nolint:errcheck
-func (r Name) MapHash() uint64 {
- // correctly hash the parts with case insensitive comparison
- var h maphash.Hash
- h.SetSeed(mapHashSeed)
- for _, part := range r.parts {
- // downcase the part for hashing
- for i := range part {
- c := part[i]
- if c >= 'A' && c <= 'Z' {
- c = c - 'A' + 'a'
- }
- h.WriteByte(c)
- }
+ scheme, host, ok := strings.Cut(s, "://")
+ if !ok {
+ host = scheme
}
- return h.Sum64()
+ n.Host = host
+
+ return n
}
-func (r Name) slice(from, to PartKind) Name {
- var v Name
- copy(v.parts[from:to+1], r.parts[from:to+1])
- return v
-}
-
-// DisplayShortest returns the shortest possible, masked display string in form:
+// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are
+// expected to be in the form:
//
-// [host/][/][:]
-//
-// # Masks
-//
-// The mask is a string that specifies which parts of the name to omit based
-// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name
-// that are the same as the mask, moving from left to right until the first
-// unequal part is found. It then moves right to left until the first unequal
-// part is found. The result is the shortest possible display string.
-//
-// Unlike a [Name] the mask can contain "?" characters which are treated as
-// wildcards. A "?" will never match a part of the name, since a valid name
-// can never contain a "?" character.
-//
-// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked
-// with ("registry.ollama.ai/library/?:latest") will produce the display string
-// ("mistral").
-//
-// If mask is the empty string, then [MaskDefault] is used.
-//
-// DisplayShortest panics if the mask is not the empty string, MaskNothing, and
-// invalid.
-//
-// # Builds
-//
-// For now, DisplayShortest does consider the build or return one in the
-// result. We can lift this restriction when needed.
-func (r Name) DisplayShortest(mask string) string {
- mask = cmp.Or(mask, MaskDefault)
- d := parseMask(mask)
- if mask != MaskNothing && r.IsZero() {
- panic("invalid Name")
+// { host } "/" { namespace } "/" { model } "/" { tag }
+func ParseNameFromFilepath(s string) (n Name) {
+ parts := strings.Split(s, string(filepath.Separator))
+ if len(parts) != 4 {
+ return Name{}
}
- for i := range PartTag {
- if !strings.EqualFold(r.parts[i], d.parts[i]) {
- break
- }
- r.parts[i] = ""
+
+ n.Host = parts[0]
+ n.Namespace = parts[1]
+ n.Model = parts[2]
+ n.Tag = parts[3]
+ if !n.IsFullyQualified() {
+ return Name{}
}
- for i := PartTag; i >= 0; i-- {
- if !strings.EqualFold(r.parts[i], d.parts[i]) {
- break
- }
- r.parts[i] = ""
+
+ return n
+}
+
+// Merge merges the host, namespace, and tag parts of the two names,
+// preferring the non-empty parts of a.
+func Merge(a, b Name) Name {
+ a.Host = cmp.Or(a.Host, b.Host)
+ a.Namespace = cmp.Or(a.Namespace, b.Namespace)
+ a.Tag = cmp.Or(a.Tag, b.Tag)
+ return a
+}
+
+// String returns the name string in the format that [ParseNameBare] accepts.
+// Empty host, namespace, tag, and digest parts are omitted along with their
+// separators; the name is not validated first.
+func (n Name) String() string {
+ var b strings.Builder
+ if n.Host != "" {
+ b.WriteString(n.Host)
+ b.WriteByte('/')
}
- return r.slice(PartHost, PartTag).DisplayLong()
-}
-
-// DisplayLongest returns the result of r.DisplayShortest(MaskNothing).
-func (r Name) DisplayLongest() string {
- return r.DisplayShortest(MaskNothing)
-}
-
-var seps = [...]string{
- PartHost: "/",
- PartNamespace: "/",
- PartModel: ":",
- PartTag: "+",
- PartBuild: "@",
- PartDigest: "",
-}
-
-// WriteTo implements io.WriterTo. It writes the fullest possible display
-// string in form:
-//
-// //:+@-
-//
-// Missing parts and their separators are not written.
-//
-// The full digest is always prefixed with "@". That is if [Name.IsValid]
-// reports false and [Name.IsResolved] reports true, then the string is
-// returned as "@-".
-func (r Name) writeTo(w io.StringWriter) error {
- var partsWritten int
- for i := range r.parts {
- if r.parts[i] == "" {
- continue
- }
- if partsWritten > 0 || i == int(PartDigest) {
- if _, err := w.WriteString(seps[i-1]); err != nil {
- return err
- }
- }
- if _, err := w.WriteString(r.parts[i]); err != nil {
- return err
- }
- partsWritten++
+ if n.Namespace != "" {
+ b.WriteString(n.Namespace)
+ b.WriteByte('/')
+ }
+ b.WriteString(n.Model)
+ if n.Tag != "" {
+ b.WriteByte(':')
+ b.WriteString(n.Tag)
+ }
+ if n.RawDigest != "" {
+ b.WriteByte('@')
+ b.WriteString(n.RawDigest)
}
- return nil
-}
-
-var builderPool = sync.Pool{
- New: func() interface{} {
- return &strings.Builder{}
- },
-}
-
-// DisplayLong returns the fullest possible display string in form:
-//
-// //:+
-//
-// If any part is missing, it is omitted from the display string.
-func (r Name) DisplayLong() string {
- b := builderPool.Get().(*strings.Builder)
- defer builderPool.Put(b)
- b.Reset()
- b.Grow(50) // arbitrarily long enough for most names
- _ = r.writeTo(b)
return b.String()
}
-// GoString implements fmt.GoStringer. It returns a string suitable for
-// debugging and logging. It is similar to [Name.DisplayLong] but it always
-// returns a string that includes all parts of the Name, with missing parts
-// replaced with a ("?").
-func (r Name) GoString() string {
- for i := range r.parts {
- r.parts[i] = cmp.Or(r.parts[i], "?")
+// DisplayShortest returns a short string version of the name.
+func (n Name) DisplayShortest() string {
+ var sb strings.Builder
+
+ if n.Host != defaultHost {
+ sb.WriteString(n.Host)
+ sb.WriteByte('/')
+ sb.WriteString(n.Namespace)
+ sb.WriteByte('/')
+ } else if n.Namespace != defaultNamespace {
+ sb.WriteString(n.Namespace)
+ sb.WriteByte('/')
}
- return r.DisplayLong()
+
+ // always include model and tag
+ sb.WriteString(n.Model)
+ sb.WriteString(":")
+ sb.WriteString(n.Tag)
+ return sb.String()
}
-// LogValue implements slog.Valuer.
-func (r Name) LogValue() slog.Value {
- return slog.StringValue(r.GoString())
-}
-
-// IsComplete reports whether the Name is fully qualified. That is it has a
-// domain, namespace, name, tag, and build.
-func (r Name) IsComplete() bool {
- return !slices.Contains(r.parts[:PartDigest], "")
-}
-
-// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
-// build part to be present.
-func (r Name) IsCompleteNoBuild() bool {
- return !slices.Contains(r.parts[:PartBuild], "")
-}
-
-// IsResolved reports true if the Name has a valid digest.
-//
-// It is possible to have a valid Name, or a complete Name that is not
-// resolved.
-func (r Name) IsResolved() bool {
- return r.Digest().IsValid()
-}
-
-// Digest returns the digest part of the Name, if any.
-//
-// If Digest returns a non-empty string, then [Name.IsResolved] will return
-// true, and digest is considered valid.
-func (r Name) Digest() Digest {
- // This was already validated by ParseName, so we can just return it.
- return Digest{r.parts[PartDigest]}
-}
-
-// EqualFold reports whether r and o are equivalent model names, ignoring
-// case.
-func (r Name) EqualFold(o Name) bool {
- return r.CompareFold(o) == 0
-}
-
-// CompareFold performs a case-insensitive cmp.Compare on r and o.
-//
-// This can be used with [slices.SortFunc].
-//
-// For simple equality checks, use [Name.EqualFold].
-func (r Name) CompareFold(o Name) int {
- return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
-}
-
-func compareFold(a, b string) int {
- return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
- return cmp.Compare(downcase(a), downcase(b))
- })
-}
-
-func downcase(r rune) rune {
- if r >= 'A' && r <= 'Z' {
- return r - 'A' + 'a'
- }
- return r
-}
-
-func (r Name) Host() string { return r.parts[PartHost] }
-func (r Name) Namespace() string { return r.parts[PartNamespace] }
-func (r Name) Model() string { return r.parts[PartModel] }
-func (r Name) Build() string { return r.parts[PartBuild] }
-func (r Name) Tag() string { return r.parts[PartTag] }
-
-// iter_Seq2 is a iter.Seq2 defined here to avoid the current build
-// restrictions in the go1.22 iter package requiring the
-// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
-// which we are not yet ready to support.
-//
-// Once we are ready to support rangefunc, this can be removed and replaced
-// with the iter.Seq2 type.
-type iter_Seq2[A, B any] func(func(A, B) bool)
-
-// Parts returns a sequence of the parts of a Name string from most specific
-// to least specific.
-//
-// It normalizes the input string by removing "http://" and "https://" only.
-// No other normalizations are performed.
-func parts(s string) iter_Seq2[PartKind, string] {
- return func(yield func(PartKind, string) bool) {
- if strings.HasPrefix(s, "http://") {
- s = strings.TrimPrefix(s, "http://")
- } else {
- s = strings.TrimPrefix(s, "https://")
- }
-
- if len(s) > MaxNamePartLen || len(s) == 0 {
- return
- }
-
- numConsecutiveDots := 0
- partLen := 0
- state, j := PartDigest, len(s)
- for i := len(s) - 1; i >= 0; i-- {
- if partLen++; partLen > MaxNamePartLen {
- // catch a part that is too long early, so
- // we don't keep spinning on it, waiting for
- // an isInValidPart check which would scan
- // over it again.
- yield(state, s[i+1:j])
- return
- }
-
- switch s[i] {
- case '@':
- switch state {
- case PartDigest:
- if !yield(PartDigest, s[i+1:j]) {
- return
- }
- if i == 0 {
- // This is the form
- // "@" which is valid.
- //
- // We're done.
- return
- }
- state, j, partLen = PartBuild, i, 0
- default:
- yield(PartExtraneous, s[i+1:j])
- return
- }
- case '+':
- switch state {
- case PartBuild, PartDigest:
- if !yield(PartBuild, s[i+1:j]) {
- return
- }
- state, j, partLen = PartTag, i, 0
- default:
- yield(PartExtraneous, s[i+1:j])
- return
- }
- case ':':
- switch state {
- case PartTag, PartBuild, PartDigest:
- if !yield(PartTag, s[i+1:j]) {
- return
- }
- state, j, partLen = PartModel, i, 0
- case PartHost:
- // noop: support for host:port
- default:
- yield(PartExtraneous, s[i+1:j])
- return
- }
- case '/':
- switch state {
- case PartModel, PartTag, PartBuild, PartDigest:
- if !yield(PartModel, s[i+1:j]) {
- return
- }
- state, j = PartNamespace, i
- case PartNamespace:
- if !yield(PartNamespace, s[i+1:j]) {
- return
- }
- state, j, partLen = PartHost, i, 0
- default:
- yield(PartExtraneous, s[i+1:j])
- return
- }
- default:
- if s[i] == '.' {
- if numConsecutiveDots++; numConsecutiveDots > 1 {
- yield(state, "")
- return
- }
- } else {
- numConsecutiveDots = 0
- }
- }
- }
-
- if state <= PartNamespace {
- yield(state, s[:j])
- } else {
- yield(PartModel, s[:j])
- }
- }
-}
-
-func (r Name) IsZero() bool {
- return r.parts == [NumParts]string{}
-}
-
-// IsValid reports if a model has at minimum a valid model part.
-func (r Name) IsValid() bool {
- // Parts ensures we only have valid parts, so no need to validate
- // them here, only check if we have a name or not.
- return r.parts[PartModel] != ""
-}
-
-// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically,
-// it trims any leading "/" and then calls [ParseName] with fill.
-func ParseNameFromURLPath(s, fill string) Name {
- s = strings.TrimPrefix(s, "/")
- return ParseName(s, fill)
-}
-
-// URLPath returns a complete, canonicalized, relative URL path using the parts of a
-// complete Name.
-//
-// The parts maintain their original case.
-//
-// Example:
-//
-// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
-func (r Name) URLPath() string {
- return r.DisplayShortest(MaskNothing)
-}
-
-// ParseNameFromFilepath parses a file path into a Name. The input string must be a
-// valid file path representation of a model name in the form:
-//
-// host/namespace/model/tag/build
-//
-// The zero valid is returned if s does not contain all path elements
-// leading up to the model part, or if any path element is an invalid part
-// for the its corresponding part kind.
-//
-// The fill string is used to fill in missing parts of any constructed Name.
-// See [ParseName] for more information on the fill string.
-func ParseNameFromFilepath(s, fill string) Name {
- var r Name
- for i := range PartBuild + 1 {
- part, rest, _ := strings.Cut(s, string(filepath.Separator))
- if !isValidPart(i, part) {
- return Name{}
- }
- r.parts[i] = part
- s = rest
- if s == "" {
- break
- }
- }
- if s != "" {
- return Name{}
- }
- if !r.IsValid() {
- return Name{}
- }
- return fillName(r, fill)
-}
-
-// Filepath returns a complete, canonicalized, relative file path using the
-// parts of a complete Name.
-//
-// Each parts is downcased, except for the build part which is upcased.
-//
-// Example:
-//
-// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD"
-func (r Name) Filepath() string {
- for i := range r.parts {
- if PartKind(i) == PartBuild {
- r.parts[i] = strings.ToUpper(r.parts[i])
- } else {
- r.parts[i] = strings.ToLower(r.parts[i])
- }
- }
- return filepath.Join(r.parts[:]...)
-}
-
-// FilepathNoBuild returns a complete, canonicalized, relative file path using
-// the parts of a complete Name, but without the build part.
-func (r Name) FilepathNoBuild() string {
- for i := range PartBuild {
- r.parts[i] = strings.ToLower(r.parts[i])
- }
- return filepath.Join(r.parts[:PartBuild]...)
-}
-
-// isValidPart reports if s contains all valid characters for the given
-// part kind.
-func isValidPart(kind PartKind, s string) bool {
- if s == "" {
+// IsValid reports whether all parts of the name are present and valid. The
+// digest is a special case, and is checked for validity only if present.
+func (n Name) IsValid() bool {
+ if n.RawDigest != "" && !isValidPart(kindDigest, n.RawDigest) {
return false
}
- var consecutiveDots int
- for _, c := range []byte(s) {
- if c == '.' {
- if consecutiveDots++; consecutiveDots >= 2 {
- return false
- }
- } else {
- consecutiveDots = 0
- }
- if !isValidByteFor(kind, c) {
+ return n.IsFullyQualified()
+}
+
+// IsFullyQualified returns true if all parts of the name are present and
+// valid without the digest.
+func (n Name) IsFullyQualified() bool {
+ var parts = []string{
+ n.Host,
+ n.Namespace,
+ n.Model,
+ n.Tag,
+ }
+ for i, part := range parts {
+ if !isValidPart(partKind(i), part) {
return false
}
}
return true
}
-func isValidByteFor(kind PartKind, c byte) bool {
- if kind == PartNamespace && c == '.' {
+// Filepath returns a canonical filepath that represents the name with each part from
+// host to tag as a directory in the form:
+//
+// {host}/{namespace}/{model}/{tag}
+//
+// It uses the system's filepath separator and ensures the path is clean.
+//
+// It panics if the name is not fully qualified. Use [Name.IsFullyQualified]
+// to check if the name is fully qualified.
+func (n Name) Filepath() string {
+ if !n.IsFullyQualified() {
+ panic("illegal attempt to get filepath of invalid name")
+ }
+ return filepath.Join(
+ strings.ToLower(filepath.Join(
+ n.Host,
+ n.Namespace,
+ n.Model,
+ )),
+ n.Tag,
+ )
+}
+
+// LogValue returns a slog.Value that represents the name as a string.
+func (n Name) LogValue() slog.Value {
+ return slog.StringValue(n.String())
+}
+
+func isValidLen(kind partKind, s string) bool {
+ switch kind {
+ case kindHost:
+ return len(s) >= 1 && len(s) <= 350
+ case kindTag:
+ return len(s) >= 1 && len(s) <= 80
+ default:
+ return len(s) >= 1 && len(s) <= 80
+ }
+}
+
+func isValidPart(kind partKind, s string) bool {
+ if !isValidLen(kind, s) {
return false
}
- if kind == PartHost && c == ':' {
- return true
+ for i := range s {
+ if i == 0 {
+ if !isAlphanumericOrUnderscore(s[i]) {
+ return false
+ }
+ continue
+ }
+ switch s[i] {
+ case '_', '-':
+ case '.':
+ if kind == kindNamespace {
+ return false
+ }
+ case ':':
+ if kind != kindHost && kind != kindDigest {
+ return false
+ }
+ default:
+ if !isAlphanumericOrUnderscore(s[i]) {
+ return false
+ }
+ }
}
- if c == '.' || c == '-' {
- return true
- }
- if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
- return true
- }
- return false
+ return true
+}
+
+func isAlphanumericOrUnderscore(c byte) bool {
+ return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
+}
+
+func cutLast(s, sep string) (before, after string, ok bool) {
+ i := strings.LastIndex(s, sep)
+ if i >= 0 {
+ return s[:i], s[i+len(sep):], true
+ }
+ return s, "", false
+}
+
+// cutPromised cuts the last part of s at the last occurrence of sep. If sep is
+// found, the part before and after sep are returned as-is unless empty, in
+// which case they are returned as MissingPart, which will cause
+// [Name.IsValid] to return false.
+func cutPromised(s, sep string) (before, after string, ok bool) {
+ before, after, ok = cutLast(s, sep)
+ if !ok {
+ return before, after, false
+ }
+ return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
+}
+
+type DigestType byte
+
+const (
+ DigestTypeInvalid DigestType = iota
+ DigestTypeSHA256
+)
+
+func (t DigestType) String() string {
+ switch t {
+ case DigestTypeSHA256:
+ return "sha256"
+ default:
+ return "invalid"
+ }
+}
+
+type Digest struct {
+ Type DigestType
+ Sum [32]byte
+}
+
+// ParseDigest parses a digest string of the form "sha256-<hex>" or
+// "sha256:<hex>", where <hex> encodes exactly 32 bytes.
+func ParseDigest(s string) (Digest, error) {
+ i := strings.IndexAny(s, "-:")
+ if i < 0 {
+ return Digest{}, fmt.Errorf("invalid digest %q", s)
+ }
+ typ, encSum := s[:i], s[i+1:]
+ if typ != "sha256" {
+ return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
+ }
+ d := Digest{Type: DigestTypeSHA256}
+ // hex.Decode indexes past d.Sum on longer input, so reject wrong sizes first.
+ if n := hex.DecodedLen(len(encSum)); n != len(d.Sum) {
+ return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
+ }
+ if _, err := hex.Decode(d.Sum[:], []byte(encSum)); err != nil {
+ return Digest{}, err
+ }
+ return d, nil
+}
+
+func (d Digest) String() string {
+ if d.Type == DigestTypeInvalid {
+ return ""
+ }
+ return fmt.Sprintf("sha256-%x", d.Sum)
+}
+
+func (d Digest) IsValid() bool {
+ return d.Type != DigestTypeInvalid
}
diff --git a/types/model/name_test.go b/types/model/name_test.go
index 8749477a..fb584291 100644
--- a/types/model/name_test.go
+++ b/types/model/name_test.go
@@ -1,709 +1,387 @@
package model
import (
- "bytes"
- "cmp"
- "fmt"
- "log/slog"
"path/filepath"
- "slices"
- "strings"
+ "reflect"
+ "runtime"
"testing"
)
-type fields struct {
- host, namespace, model, tag, build string
- digest string
-}
+const (
+ part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
+ part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
+)
-func fieldsFromName(p Name) fields {
- return fields{
- host: p.parts[PartHost],
- namespace: p.parts[PartNamespace],
- model: p.parts[PartModel],
- tag: p.parts[PartTag],
- build: p.parts[PartBuild],
- digest: p.parts[PartDigest],
- }
-}
-
-var testNames = map[string]fields{
- "mistral:latest": {model: "mistral", tag: "latest"},
- "mistral": {model: "mistral"},
- "mistral:30B": {model: "mistral", tag: "30B"},
- "mistral:7b": {model: "mistral", tag: "7b"},
- "mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
- "mistral+KQED": {model: "mistral", build: "KQED"},
- "mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
- "mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
- "llama2": {model: "llama2"},
- "user/model": {namespace: "user", model: "model"},
- "example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
- "example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
- "localhost:5000/ns/mistral": {host: "localhost:5000", namespace: "ns", model: "mistral"},
-
- // invalid digest
- "mistral:latest@invalid256-": {},
- "mistral:latest@-123": {},
- "mistral:latest@!-123": {},
- "mistral:latest@1-!": {},
- "mistral:latest@": {},
-
- // resolved
- "x@sha123-1": {model: "x", digest: "sha123-1"},
- "@sha456-2": {digest: "sha456-2"},
-
- "@@sha123-1": {},
-
- // preserves case for build
- "x+b": {model: "x", build: "b"},
-
- // invalid (includes fuzzing trophies)
- " / / : + ": {},
- " / : + ": {},
- " : + ": {},
- " + ": {},
- " : ": {},
- " / ": {},
- " /": {},
- "/ ": {},
- "/": {},
- ":": {},
- "+": {},
-
- // (".") in namepsace is not allowed
- "invalid.com/7b+x": {},
-
- "invalid:7b+Q4_0:latest": {},
- "in valid": {},
- "invalid/y/z/foo": {},
- "/0": {},
- "0 /0": {},
- "0 /": {},
- "0/": {},
- ":/0": {},
- "+0/00000": {},
- "0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
- "0//0": {},
- "m+^^^": {},
- "file:///etc/passwd": {},
- "file:///etc/passwd:latest": {},
- "file:///etc/passwd:latest+u": {},
-
- ":x": {},
- "+x": {},
- "x+": {},
-
- // Disallow ("\.+") in any part to prevent path traversal anywhere
- // we convert the name to a path.
- "../etc/passwd": {},
- ".../etc/passwd": {},
- "./../passwd": {},
- "./0+..": {},
-
- strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
- strings.Repeat("a", MaxNamePartLen+1): {},
-}
-
-// TestConsecutiveDots tests that consecutive dots are not allowed in any
-// part, to avoid path traversal. There also are some tests in testNames, but
-// this test is more exhaustive and exists to emphasize the importance of
-// preventing path traversal.
-func TestNameConsecutiveDots(t *testing.T) {
- for i := 1; i < 10; i++ {
- s := strings.Repeat(".", i)
- if i > 1 {
- if g := ParseName(s, FillNothing).DisplayLong(); g != "" {
- t.Errorf("ParseName(%q) = %q; want empty string", s, g)
- }
- } else {
- if g := ParseName(s, FillNothing).DisplayLong(); g != s {
- t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
- }
- }
- }
-}
-
-func TestNameParts(t *testing.T) {
- var p Name
- if w, g := int(NumParts), len(p.parts); w != g {
- t.Errorf("Parts() = %d; want %d", g, w)
- }
-}
-
-func TestNamePartString(t *testing.T) {
- if g := PartKind(-2).String(); g != "Unknown" {
- t.Errorf("Unknown part = %q; want %q", g, "Unknown")
- }
- for kind, name := range kindNames {
- if g := kind.String(); g != name {
- t.Errorf("%s = %q; want %q", kind, g, name)
- }
- }
-}
-
-func TestParseName(t *testing.T) {
- for baseName, want := range testNames {
- for _, prefix := range []string{"", "https://", "http://"} {
- // We should get the same results with or without the
- // http(s) prefixes
- s := prefix + baseName
-
- t.Run(s, func(t *testing.T) {
- name := ParseName(s, FillNothing)
- got := fieldsFromName(name)
- if got != want {
- t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
- }
-
- // test round-trip
- if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) {
- t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName)
- }
- })
- }
- }
-}
-
-func TestParseNameFill(t *testing.T) {
- cases := []struct {
- in string
- fill string
- want string
- }{
- {"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"},
- {"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"},
- {"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"},
-
- // Invalid
- {"", "example.com/library/?:latest+Q4_0", ""},
- {"llama2:?", "example.com/library/?:latest+Q4_0", ""},
- }
-
- for _, tt := range cases {
- t.Run(tt.in, func(t *testing.T) {
- name := ParseName(tt.in, tt.fill)
- if g := name.DisplayLong(); g != tt.want {
- t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want)
- }
- })
- }
-
- t.Run("invalid fill", func(t *testing.T) {
- defer func() {
- if recover() == nil {
- t.Fatal("expected panic")
- }
- }()
- ParseName("x", "^")
- })
-}
-
-func TestParseNameHTTPDoublePrefixStrip(t *testing.T) {
- cases := []string{
- "http://https://valid.com/valid/valid:latest",
- "https://http://valid.com/valid/valid:latest",
- }
- for _, s := range cases {
- t.Run(s, func(t *testing.T) {
- name := ParseName(s, FillNothing)
- if name.IsValid() {
- t.Errorf("expected invalid path; got %#v", name)
- }
- })
- }
-
-}
-
-func TestCompleteWithAndWithoutBuild(t *testing.T) {
+func TestParseNameParts(t *testing.T) {
cases := []struct {
in string
- complete bool
- completeNoBuild bool
+ want Name
+ wantFilepath string
+ wantValidDigest bool
}{
- {"", false, false},
- {"incomplete/mistral:7b+x", false, false},
- {"incomplete/mistral:7b+Q4_0", false, false},
- {"incomplete:7b+x", false, false},
- {"complete.com/x/mistral:latest+Q4_0", true, true},
- {"complete.com/x/mistral:latest", false, true},
+ {
+ in: "registry.ollama.ai/library/dolphin-mistral:7b-v2.6-dpo-laser-q6_K",
+ want: Name{
+ Host: "registry.ollama.ai",
+ Namespace: "library",
+ Model: "dolphin-mistral",
+ Tag: "7b-v2.6-dpo-laser-q6_K",
+ },
+ wantFilepath: filepath.Join("registry.ollama.ai", "library", "dolphin-mistral", "7b-v2.6-dpo-laser-q6_K"),
+ },
+ {
+ in: "scheme://host:port/namespace/model:tag",
+ want: Name{
+ Host: "host:port",
+ Namespace: "namespace",
+ Model: "model",
+ Tag: "tag",
+ },
+ wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
+ },
+ {
+ in: "host/namespace/model:tag",
+ want: Name{
+ Host: "host",
+ Namespace: "namespace",
+ Model: "model",
+ Tag: "tag",
+ },
+ wantFilepath: filepath.Join("host", "namespace", "model", "tag"),
+ },
+ {
+ in: "host:port/namespace/model:tag",
+ want: Name{
+ Host: "host:port",
+ Namespace: "namespace",
+ Model: "model",
+ Tag: "tag",
+ },
+ wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
+ },
+ {
+ in: "host/namespace/model",
+ want: Name{
+ Host: "host",
+ Namespace: "namespace",
+ Model: "model",
+ },
+ wantFilepath: filepath.Join("host", "namespace", "model", "latest"),
+ },
+ {
+ in: "host:port/namespace/model",
+ want: Name{
+ Host: "host:port",
+ Namespace: "namespace",
+ Model: "model",
+ },
+ wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"),
+ },
+ {
+ in: "namespace/model",
+ want: Name{
+ Namespace: "namespace",
+ Model: "model",
+ },
+ wantFilepath: filepath.Join("registry.ollama.ai", "namespace", "model", "latest"),
+ },
+ {
+ in: "model",
+ want: Name{
+ Model: "model",
+ },
+ wantFilepath: filepath.Join("registry.ollama.ai", "library", "model", "latest"),
+ },
+ {
+ in: "h/nn/mm:t",
+ want: Name{
+ Host: "h",
+ Namespace: "nn",
+ Model: "mm",
+ Tag: "t",
+ },
+ wantFilepath: filepath.Join("h", "nn", "mm", "t"),
+ },
+ {
+ in: part80 + "/" + part80 + "/" + part80 + ":" + part80,
+ want: Name{
+ Host: part80,
+ Namespace: part80,
+ Model: part80,
+ Tag: part80,
+ },
+ wantFilepath: filepath.Join(part80, part80, part80, part80),
+ },
+ {
+ in: part350 + "/" + part80 + "/" + part80 + ":" + part80,
+ want: Name{
+ Host: part350,
+ Namespace: part80,
+ Model: part80,
+ Tag: part80,
+ },
+ wantFilepath: filepath.Join(part350, part80, part80, part80),
+ },
+ {
+ in: "@digest",
+ want: Name{
+ RawDigest: "digest",
+ },
+ wantValidDigest: false,
+ },
+ {
+ in: "model@sha256:123",
+ want: Name{
+ Model: "model",
+ RawDigest: "sha256:123",
+ },
+ wantValidDigest: true,
+ },
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
- p := ParseName(tt.in, FillNothing)
- t.Logf("ParseName(%q) = %#v", tt.in, p)
- if g := p.IsComplete(); g != tt.complete {
- t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
+ got := ParseNameBare(tt.in)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
}
- if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
- t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
- }
- })
- }
- // Complete uses Parts which returns a slice, but it should be
- // inlined when used in Complete, preventing any allocations or
- // escaping to the heap.
- allocs := testing.AllocsPerRun(1000, func() {
- keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
- })
- if allocs > 0 {
- t.Errorf("Complete allocs = %v; want 0", allocs)
- }
-}
-
-func TestNameLogValue(t *testing.T) {
- cases := []string{
- "example.com/library/mistral:latest+Q4_0",
- "mistral:latest",
- "mistral:7b+Q4_0",
- }
- for _, s := range cases {
- t.Run(s, func(t *testing.T) {
- var b bytes.Buffer
- log := slog.New(slog.NewTextHandler(&b, nil))
- name := ParseName(s, FillNothing)
- log.Info("", "name", name)
- want := fmt.Sprintf("name=%s", name.GoString())
- got := b.String()
- if !strings.Contains(got, want) {
- t.Errorf("expected log output to contain %q; got %q", want, got)
+ got = ParseName(tt.in)
+ if tt.wantFilepath != "" && got.Filepath() != tt.wantFilepath {
+ t.Errorf("parseName(%q).Filepath() = %q; want %q", tt.in, got.Filepath(), tt.wantFilepath)
}
})
}
}
-func TestNameGoString(t *testing.T) {
+var testCases = map[string]bool{ // name -> valid
+ "": false,
+
+ "_why/_the/_lucky:_stiff": true,
+
+ // minimal
+ "h/n/m:t@d": true,
+
+ "host/namespace/model:tag": true,
+ "host/namespace/model": false,
+ "namespace/model": false,
+ "model": false,
+ "@sha256-1000000000000000000000000000000000000000000000000000000000000000": false,
+ "model@sha256-1000000000000000000000000000000000000000000000000000000000000000": false,
+ "model@sha256:1000000000000000000000000000000000000000000000000000000000000000": false,
+
+ // long (but valid)
+ part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
+ part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
+
+ "h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
+ "h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
+
+ // unqualified
+ "m": false,
+ "n/m:": false,
+ "h/n/m": false,
+ "@t": false,
+ "m@d": false,
+
+ // invalids
+ "^": false,
+ "mm:": false,
+ "/nn/mm": false,
+ "//": false,
+ "//mm": false,
+ "hh//": false,
+ "//mm:@": false,
+ "00@": false,
+ "@": false,
+
+ // not starting with alphanum
+ "-hh/nn/mm:tt@dd": false,
+ "hh/-nn/mm:tt@dd": false,
+ "hh/nn/-mm:tt@dd": false,
+ "hh/nn/mm:-tt@dd": false,
+ "hh/nn/mm:tt@-dd": false,
+
+ // hosts
+ "host:https/namespace/model:tag": true,
+
+ // colon in non-host part before tag
+ "host/name:space/model:tag": false,
+}
+
+func TestNameparseNameDefault(t *testing.T) {
+ const name = "xx"
+ n := ParseName(name)
+ got := n.String()
+ want := "registry.ollama.ai/library/xx:latest"
+ if got != want {
+ t.Errorf("parseName(%q).String() = %q; want %q", name, got, want)
+ }
+}
+
+func TestNameIsValid(t *testing.T) {
+ var numStringTests int
+ for s, want := range testCases {
+ n := ParseNameBare(s)
+ got := n.IsValid()
+ if got != want {
+ t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
+ }
+
+ // Test roundtrip with String
+ if got {
+ got := ParseNameBare(s).String()
+ if got != s {
+ t.Errorf("parseName(%q).String() = %q; want %q", s, got, s)
+ }
+ numStringTests++
+ }
+ }
+
+ if numStringTests == 0 {
+ t.Errorf("no tests for Name.String")
+ }
+}
+
+func TestNameIsValidPart(t *testing.T) {
cases := []struct {
- name string
- in string
- wantString string
- wantGoString string // default is tt.in
+ kind partKind
+ s string
+ want bool
}{
- {
- name: "Complete Name",
- in: "example.com/library/mistral:latest+Q4_0",
- wantGoString: "example.com/library/mistral:latest+Q4_0@?",
- },
- {
- name: "Short Name",
- in: "mistral:latest",
- wantGoString: "?/?/mistral:latest+?@?",
- },
- {
- name: "Long Name",
- in: "library/mistral:latest",
- wantGoString: "?/library/mistral:latest+?@?",
- },
- {
- name: "Case Preserved",
- in: "Library/Mistral:Latest",
- wantGoString: "?/Library/Mistral:Latest+?@?",
- },
- {
- name: "With digest",
- in: "Library/Mistral:Latest@sha256-123456",
- wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
- },
+ {kind: kindHost, s: "", want: false},
+ {kind: kindHost, s: "a", want: true},
+ {kind: kindHost, s: "a.", want: true},
+ {kind: kindHost, s: "a.b", want: true},
+ {kind: kindHost, s: "a:123", want: true},
+ {kind: kindHost, s: "a:123/aa/bb", want: false},
+ {kind: kindNamespace, s: "bb", want: true},
+ {kind: kindNamespace, s: "a.", want: false},
+ {kind: kindModel, s: "-h", want: false},
+ {kind: kindDigest, s: "sha256-1000000000000000000000000000000000000000000000000000000000000000", want: true},
+ }
+ for _, tt := range cases {
+ t.Run(tt.s, func(t *testing.T) {
+ got := isValidPart(tt.kind, tt.s)
+ if got != tt.want {
+ t.Errorf("isValidPart(%s, %q) = %v; want %v", tt.kind, tt.s, got, tt.want)
+ }
+ })
}
+}
+
+func TestFilepathAllocs(t *testing.T) {
+ n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG")
+ allocs := testing.AllocsPerRun(1000, func() {
+ n.Filepath()
+ })
+ var allowedAllocs float64 = 3
+ if runtime.GOOS == "windows" {
+ allowedAllocs = 5
+ }
+ if allocs > allowedAllocs {
+ t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs)
+ }
+}
+
+const (
+ validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
+ validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
+)
+
+func TestParseDigest(t *testing.T) {
+ cases := []struct {
+ in string
+ want string
+ }{
+ {"", ""}, // empty
+ {"sha123-12", ""}, // invalid type
+ {"sha256-", ""}, // invalid sum
+ {"sha256-123", ""}, // invalid odd length sum
+
+ {validSha256, validSha256},
+ {validSha256Old, validSha256},
+ }
for _, tt := range cases {
- t.Run(tt.name, func(t *testing.T) {
- p := ParseName(tt.in, FillNothing)
- tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
- if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
- t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
+ t.Run(tt.in, func(t *testing.T) {
+ got, err := ParseDigest(tt.in)
+ if err != nil {
+ if tt.want != "" {
+ t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
+ }
+ return
+ }
+ if got.String() != tt.want {
+ t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
}
})
}
}
-func TestDisplayLongest(t *testing.T) {
- g := ParseName("example.com/library/mistral:latest+Q4_0", FillNothing).DisplayLongest()
- if g != "example.com/library/mistral:latest" {
- t.Errorf("got = %q; want %q", g, "example.com/library/mistral:latest")
+func TestParseNameFromFilepath(t *testing.T) {
+ cases := map[string]Name{
+ filepath.Join("host", "namespace", "model", "tag"): {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"},
+ filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"},
+ filepath.Join("namespace", "model", "tag"): {},
+ filepath.Join("model", "tag"): {},
+ filepath.Join("model"): {},
+ filepath.Join("..", "..", "model", "tag"): {},
+ filepath.Join("", "namespace", ".", "tag"): {},
+ filepath.Join(".", ".", ".", "."): {},
+ filepath.Join("/", "path", "to", "random", "file"): {},
+ }
+
+ for in, want := range cases {
+ t.Run(in, func(t *testing.T) {
+ got := ParseNameFromFilepath(in)
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("parseNameFromFilepath(%q) = %v; want %v", in, got, want)
+ }
+ })
}
}
func TestDisplayShortest(t *testing.T) {
- cases := []struct {
- in string
- mask string
- want string
- wantPanic bool
- }{
- {"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
- {"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
- {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
- {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
-
- // case-insensitive
- {"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
- {"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
- {"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
- {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
- {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
-
- // zero value
- {"", MaskDefault, "", true},
-
- // invalid mask
- {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
-
- // DefaultMask
- {"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false},
-
- // Auto-Fill
- {"x", "example.com/library/_:latest", "x", false},
- {"x", "example.com/library/_:latest+Q4_0", "x", false},
- {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
- {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
+ cases := map[string]string{
+ "registry.ollama.ai/library/model:latest": "model:latest",
+ "registry.ollama.ai/library/model:tag": "model:tag",
+ "registry.ollama.ai/namespace/model:tag": "namespace/model:tag",
+ "host/namespace/model:tag": "host/namespace/model:tag",
+ "host/library/model:tag": "host/library/model:tag",
}
- for _, tt := range cases {
- t.Run("", func(t *testing.T) {
- defer func() {
- if tt.wantPanic {
- if recover() == nil {
- t.Errorf("expected panic")
- }
+ for in, want := range cases {
+ t.Run(in, func(t *testing.T) {
+ got := ParseNameBare(in).DisplayShortest()
+ if got != want {
+ t.Errorf("parseName(%q).DisplayShortest() = %q; want %q", in, got, want)
+ }
+ })
+ }
+}
+
+func FuzzName(f *testing.F) {
+ for s := range testCases {
+ f.Add(s)
+ }
+ f.Fuzz(func(t *testing.T, s string) {
+ n := ParseNameBare(s)
+ if n.IsValid() {
+ parts := [...]string{n.Host, n.Namespace, n.Model, n.Tag, n.RawDigest}
+ for _, part := range parts {
+ if part == ".." {
+ t.Errorf("unexpected .. as valid part")
+ }
+ if len(part) > 350 {
+ t.Errorf("part too long: %q", part)
}
- }()
-
- p := ParseName(tt.in, FillNothing)
- t.Logf("ParseName(%q) = %#v", tt.in, p)
- if g := p.DisplayShortest(tt.mask); g != tt.want {
- t.Errorf("got = %q; want %q", g, tt.want)
}
- })
- }
-}
-
-func TestParseNameAllocs(t *testing.T) {
- allocs := testing.AllocsPerRun(1000, func() {
- keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
- })
- if allocs > 0 {
- t.Errorf("ParseName allocs = %v; want 0", allocs)
- }
-}
-
-func BenchmarkParseName(b *testing.B) {
- b.ReportAllocs()
-
- for range b.N {
- keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
- }
-}
-
-func FuzzParseNameFromFilepath(f *testing.F) {
- f.Add("example.com/library/mistral/7b/Q4_0")
- f.Add("example.com/../mistral/7b/Q4_0")
- f.Add("example.com/x/../7b/Q4_0")
- f.Add("example.com/x/../7b")
- f.Fuzz(func(t *testing.T, s string) {
- name := ParseNameFromFilepath(s, FillNothing)
- if strings.Contains(s, "..") && !name.IsZero() {
- t.Fatalf("non-zero value for path with '..': %q", s)
- }
- if name.IsValid() == name.IsZero() {
- t.Errorf("expected valid path to be non-zero value; got %#v", name)
+ if n.String() != s {
+ t.Errorf("String() = %q; want %q", n.String(), s)
+ }
}
+
})
}
-
-func FuzzParseName(f *testing.F) {
- f.Add("example.com/mistral:7b+Q4_0")
- f.Add("example.com/mistral:7b+q4_0")
- f.Add("example.com/mistral:7b+x")
- f.Add("x/y/z:8n+I")
- f.Add(":x")
- f.Add("@sha256-123456")
- f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
- f.Add(":@!@")
- f.Add("...")
- f.Fuzz(func(t *testing.T, s string) {
- r0 := ParseName(s, FillNothing)
-
- if strings.Contains(s, "..") && !r0.IsZero() {
- t.Fatalf("non-zero value for path with '..': %q", s)
- }
-
- if !r0.IsValid() && !r0.IsResolved() {
- if !r0.EqualFold(Name{}) {
- t.Errorf("expected invalid path to be zero value; got %#v", r0)
- }
- t.Skipf("invalid path: %q", s)
- }
-
- for _, p := range r0.parts {
- if len(p) > MaxNamePartLen {
- t.Errorf("part too long: %q", p)
- }
- }
-
- if !strings.EqualFold(r0.DisplayLong(), s) {
- t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s)
- }
-
- r1 := ParseName(r0.DisplayLong(), FillNothing)
- if !r0.EqualFold(r1) {
- t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
- }
- })
-}
-
-func TestNameStringAllocs(t *testing.T) {
- name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing)
- allocs := testing.AllocsPerRun(1000, func() {
- keep(name.DisplayLong())
- })
- if allocs > 1 {
- t.Errorf("String allocs = %v; want 0", allocs)
- }
-}
-
-func TestNamePath(t *testing.T) {
- cases := []struct {
- in string
- want string
- }{
- {"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"},
-
- // incomplete
- {"example.com/library/mistral:latest", "example.com/library/mistral:latest"},
- {"", ""},
- }
- for _, tt := range cases {
- t.Run(tt.in, func(t *testing.T) {
- p := ParseName(tt.in, FillNothing)
- t.Logf("ParseName(%q) = %#v", tt.in, p)
- if g := p.URLPath(); g != tt.want {
- t.Errorf("got = %q; want %q", g, tt.want)
- }
- })
- }
-}
-
-func TestNameFilepath(t *testing.T) {
- cases := []struct {
- in string
- want string
- wantNoBuild string
- }{
- {
- in: "example.com/library/mistral:latest+Q4_0",
- want: "example.com/library/mistral/latest/Q4_0",
- wantNoBuild: "example.com/library/mistral/latest",
- },
- {
- in: "Example.Com/Library/Mistral:Latest+Q4_0",
- want: "example.com/library/mistral/latest/Q4_0",
- wantNoBuild: "example.com/library/mistral/latest",
- },
- {
- in: "Example.Com/Library/Mistral:Latest+Q4_0",
- want: "example.com/library/mistral/latest/Q4_0",
- wantNoBuild: "example.com/library/mistral/latest",
- },
- {
- in: "example.com/library/mistral:latest",
- want: "example.com/library/mistral/latest",
- wantNoBuild: "example.com/library/mistral/latest",
- },
- {
- in: "",
- want: "",
- wantNoBuild: "",
- },
- }
- for _, tt := range cases {
- t.Run(tt.in, func(t *testing.T) {
- p := ParseName(tt.in, FillNothing)
- t.Logf("ParseName(%q) = %#v", tt.in, p)
- g := p.Filepath()
- g = filepath.ToSlash(g)
- if g != tt.want {
- t.Errorf("got = %q; want %q", g, tt.want)
- }
- g = p.FilepathNoBuild()
- g = filepath.ToSlash(g)
- if g != tt.wantNoBuild {
- t.Errorf("got = %q; want %q", g, tt.wantNoBuild)
- }
- })
- }
-}
-
-func TestParseNameFilepath(t *testing.T) {
- cases := []struct {
- in string
- fill string // default is FillNothing
- want string
- }{
- {
- in: "example.com/library/mistral/latest/Q4_0",
- want: "example.com/library/mistral:latest+Q4_0",
- },
- {
- in: "example.com/library/mistral/latest",
- fill: "?/?/?:latest+Q4_0",
- want: "example.com/library/mistral:latest+Q4_0",
- },
- {
- in: "example.com/library/mistral",
- fill: "?/?/?:latest+Q4_0",
- want: "example.com/library/mistral:latest+Q4_0",
- },
- {
- in: "example.com/library",
- want: "",
- },
- {
- in: "example.com/",
- want: "",
- },
- {
- in: "example.com/^/mistral/latest/Q4_0",
- want: "",
- },
- {
- in: "example.com/library/mistral/../Q4_0",
- want: "",
- },
- {
- in: "example.com/library/mistral/latest/Q4_0/extra",
- want: "",
- },
- }
- for _, tt := range cases {
- t.Run(tt.in, func(t *testing.T) {
- in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator))
- fill := cmp.Or(tt.fill, FillNothing)
- want := ParseName(tt.want, fill)
- if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) {
- t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
- }
- })
- }
-}
-
-func TestParseNameFromPath(t *testing.T) {
- cases := []struct {
- in string
- want string
- fill string // default is FillNothing
- }{
- {
- in: "example.com/library/mistral:latest+Q4_0",
- want: "example.com/library/mistral:latest+Q4_0",
- },
- {
- in: "/example.com/library/mistral:latest+Q4_0",
- want: "example.com/library/mistral:latest+Q4_0",
- },
- {
- in: "/example.com/library/mistral",
- want: "example.com/library/mistral",
- },
- {
- in: "/example.com/library/mistral",
- fill: "?/?/?:latest+Q4_0",
- want: "example.com/library/mistral:latest+Q4_0",
- },
- {
- in: "/example.com/library",
- want: "",
- },
- {
- in: "/example.com/",
- want: "",
- },
- {
- in: "/example.com/^/mistral/latest",
- want: "",
- },
- }
- for _, tt := range cases {
- t.Run(tt.in, func(t *testing.T) {
- fill := cmp.Or(tt.fill, FillNothing)
- if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want {
- t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
- }
- })
- }
-}
-
-func ExampleName_MapHash() {
- m := map[uint64]bool{}
-
- // key 1
- m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true
- m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true
- m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true
-
- // key 2
- m[ParseName("mistral:LATest", FillNothing).MapHash()] = true
-
- fmt.Println(len(m))
- // Output:
- // 2
-}
-
-func ExampleName_CompareFold_sort() {
- names := []Name{
- ParseName("mistral:latest", FillNothing),
- ParseName("mistRal:7b+q4", FillNothing),
- ParseName("MIstral:7b", FillNothing),
- }
-
- slices.SortFunc(names, Name.CompareFold)
-
- for _, n := range names {
- fmt.Println(n.DisplayLong())
- }
-
- // Output:
- // MIstral:7b
- // mistRal:7b+q4
- // mistral:latest
-}
-
-func ExampleName_completeAndResolved() {
- for _, s := range []string{
- "x/y/z:latest+q4_0@sha123-1",
- "x/y/z:latest+q4_0",
- "@sha123-1",
- } {
- name := ParseName(s, FillNothing)
- fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
- }
-
- // Output:
- // complete:true resolved:true digest:sha123-1
- // complete:true resolved:false digest:
- // complete:false resolved:true digest:sha123-1
-}
-
-func ExampleName_DisplayShortest() {
- name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing)
-
- fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
- fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
- fmt.Println(name.DisplayShortest("example.com/_/_:_"))
- fmt.Println(name.DisplayShortest("_/_/_:_"))
-
- // Default
- name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing)
- fmt.Println(name.DisplayShortest(""))
-
- // Output:
- // mistral
- // jmorganca/mistral
- // jmorganca/mistral:latest
- // example.com/jmorganca/mistral:latest
- // mistral
-}
-
-func keep[T any](v T) T { return v }
diff --git a/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa b/types/model/testdata/fuzz/FuzzName/d37463aa416f6bab
similarity index 53%
rename from types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
rename to types/model/testdata/fuzz/FuzzName/d37463aa416f6bab
index 0cdf1eac..0034d9f5 100644
--- a/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
+++ b/types/model/testdata/fuzz/FuzzName/d37463aa416f6bab
@@ -1,2 +1,2 @@
go test fuzz v1
-string("/0")
+string("00@")
diff --git a/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 b/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
deleted file mode 100644
index c5d09a4c..00000000
--- a/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
+++ /dev/null
@@ -1,2 +0,0 @@
-go test fuzz v1
-string("0//0")
diff --git a/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d b/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d
deleted file mode 100644
index 880ce7a3..00000000
--- a/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d
+++ /dev/null
@@ -1,2 +0,0 @@
-go test fuzz v1
-string("0 /0")
diff --git a/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab b/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab
deleted file mode 100644
index fa981c52..00000000
--- a/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab
+++ /dev/null
@@ -1,2 +0,0 @@
-go test fuzz v1
-string("+0/00000")
diff --git a/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 b/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608
deleted file mode 100644
index 0a66beb8..00000000
--- a/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608
+++ /dev/null
@@ -1,2 +0,0 @@
-go test fuzz v1
-string(":")
diff --git a/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 b/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948
deleted file mode 100644
index db07727d..00000000
--- a/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948
+++ /dev/null
@@ -1,2 +0,0 @@
-go test fuzz v1
-string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")
diff --git a/types/structs/structs.go b/types/structs/structs.go
deleted file mode 100644
index 52929ebf..00000000
--- a/types/structs/structs.go
+++ /dev/null
@@ -1,15 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-// Package structs contains the Incomparable type.
-package structs
-
-// Incomparable is a zero-width incomparable type. If added as the
-// first field in a struct, it marks that struct as not comparable
-// (can't do == or be a map key) and usually doesn't add any width to
-// the struct (unless the struct has only small fields).
-//
-// By making a struct incomparable, you can prevent misuse (prevent
-// people from using ==), but also you can shrink generated binaries,
-// as the compiler can omit equality funcs from the binary.
-type Incomparable [0]func()