Merge pull request #6128 from ollama/mxyng/lint
enable gofmt/gofumpt/goimports/tenv
This commit is contained in:
commit
77ccbf04dc
68 changed files with 199 additions and 149 deletions
1
.gitattributes
vendored
1
.gitattributes
vendored
|
@ -1 +1,2 @@
|
||||||
llm/ext_server/* linguist-vendored
|
llm/ext_server/* linguist-vendored
|
||||||
|
* text eol=lf
|
||||||
|
|
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
|
@ -273,7 +273,7 @@ jobs:
|
||||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||||
- uses: golangci/golangci-lint-action@v6
|
- uses: golangci/golangci-lint-action@v6
|
||||||
with:
|
with:
|
||||||
args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }}
|
args: --timeout 8m0s -v
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|
|
@ -7,22 +7,32 @@ linters:
|
||||||
- bodyclose
|
- bodyclose
|
||||||
- containedctx
|
- containedctx
|
||||||
- contextcheck
|
- contextcheck
|
||||||
|
- errcheck
|
||||||
- exportloopref
|
- exportloopref
|
||||||
|
- gci
|
||||||
- gocheckcompilerdirectives
|
- gocheckcompilerdirectives
|
||||||
# conditionally enable this on linux/macos
|
- gofmt
|
||||||
# - gofmt
|
- gofumpt
|
||||||
# - goimports
|
- gosimple
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
- intrange
|
- intrange
|
||||||
|
- makezero
|
||||||
- misspell
|
- misspell
|
||||||
- nilerr
|
- nilerr
|
||||||
- nolintlint
|
- nolintlint
|
||||||
- nosprintfhostport
|
- nosprintfhostport
|
||||||
|
- staticcheck
|
||||||
|
- tenv
|
||||||
- testifylint
|
- testifylint
|
||||||
- unconvert
|
- unconvert
|
||||||
- unused
|
- unused
|
||||||
|
- usestdlibvars
|
||||||
- wastedassign
|
- wastedassign
|
||||||
- whitespace
|
- whitespace
|
||||||
- usestdlibvars
|
linters-settings:
|
||||||
|
gci:
|
||||||
|
sections: [standard, default, localmodule]
|
||||||
severity:
|
severity:
|
||||||
default-severity: error
|
default-severity: error
|
||||||
rules:
|
rules:
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -172,7 +173,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||||
}
|
}
|
||||||
|
|
||||||
if errorResponse.Error != "" {
|
if errorResponse.Error != "" {
|
||||||
return fmt.Errorf(errorResponse.Error)
|
return errors.New(errorResponse.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.StatusCode >= http.StatusBadRequest {
|
if response.StatusCode >= http.StatusBadRequest {
|
||||||
|
|
|
@ -2,7 +2,7 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -192,7 +192,7 @@ func TestUseMmapFormatParams(t *testing.T) {
|
||||||
"use_mmap": {"foo"},
|
"use_mmap": {"foo"},
|
||||||
},
|
},
|
||||||
exp: nil,
|
exp: nil,
|
||||||
err: fmt.Errorf("invalid bool value [foo]"),
|
err: errors.New("invalid bool value [foo]"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
package lifecycle
|
package lifecycle
|
||||||
|
|
||||||
import "fmt"
|
import "errors"
|
||||||
|
|
||||||
func GetStarted() error {
|
func GetStarted() error {
|
||||||
return fmt.Errorf("GetStarted not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,6 @@ func GetStarted() error {
|
||||||
Sys: &syscall.SysProcAttr{CreationFlags: CREATE_NEW_CONSOLE, HideWindow: false},
|
Sys: &syscall.SysProcAttr{CreationFlags: CREATE_NEW_CONSOLE, HideWindow: false},
|
||||||
}
|
}
|
||||||
proc, err := os.StartProcess(args[0], args, attrs)
|
proc, err := os.StartProcess(args[0], args, attrs)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to start getting started shell %w", err)
|
return fmt.Errorf("unable to start getting started shell %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ func InitLogging() {
|
||||||
// TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion
|
// TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion
|
||||||
} else {
|
} else {
|
||||||
rotateLogs(AppLogFile)
|
rotateLogs(AppLogFile)
|
||||||
logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error(fmt.Sprintf("failed to create server log %v", err))
|
slog.Error(fmt.Sprintf("failed to create server log %v", err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -5,5 +5,5 @@ package lifecycle
|
||||||
import "log/slog"
|
import "log/slog"
|
||||||
|
|
||||||
func ShowLogs() {
|
func ShowLogs() {
|
||||||
slog.Warn("ShowLogs not yet implemented")
|
slog.Warn("not implemented")
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ func TestRotateLogs(t *testing.T) {
|
||||||
// No log exists
|
// No log exists
|
||||||
rotateLogs(logFile)
|
rotateLogs(logFile)
|
||||||
|
|
||||||
require.NoError(t, os.WriteFile(logFile, []byte("1"), 0644))
|
require.NoError(t, os.WriteFile(logFile, []byte("1"), 0o644))
|
||||||
assert.FileExists(t, logFile)
|
assert.FileExists(t, logFile)
|
||||||
// First rotation
|
// First rotation
|
||||||
rotateLogs(logFile)
|
rotateLogs(logFile)
|
||||||
|
@ -32,7 +32,7 @@ func TestRotateLogs(t *testing.T) {
|
||||||
assert.NoFileExists(t, logFile)
|
assert.NoFileExists(t, logFile)
|
||||||
|
|
||||||
for i := 2; i <= LogRotationCount+1; i++ {
|
for i := 2; i <= LogRotationCount+1; i++ {
|
||||||
require.NoError(t, os.WriteFile(logFile, []byte(strconv.Itoa(i)), 0644))
|
require.NoError(t, os.WriteFile(logFile, []byte(strconv.Itoa(i)), 0o644))
|
||||||
assert.FileExists(t, logFile)
|
assert.FileExists(t, logFile)
|
||||||
rotateLogs(logFile)
|
rotateLogs(logFile)
|
||||||
assert.NoFileExists(t, logFile)
|
assert.NoFileExists(t, logFile)
|
||||||
|
|
|
@ -55,7 +55,7 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
rotateLogs(ServerLogFile)
|
rotateLogs(ServerLogFile)
|
||||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create server log: %w", err)
|
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
||||||
query.Add("os", runtime.GOOS)
|
query.Add("os", runtime.GOOS)
|
||||||
query.Add("arch", runtime.GOARCH)
|
query.Add("arch", runtime.GOARCH)
|
||||||
query.Add("version", version.Version)
|
query.Add("version", version.Version)
|
||||||
query.Add("ts", fmt.Sprintf("%d", time.Now().Unix()))
|
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
|
||||||
nonce, err := auth.NewNonce(rand.Reader, 16)
|
nonce, err := auth.NewNonce(rand.Reader, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -4,9 +4,9 @@ package lifecycle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||||
return fmt.Errorf("DoUpgrade not yet implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package lifecycle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
@ -15,7 +16,7 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||||
return fmt.Errorf("failed to lookup downloads: %s", err)
|
return fmt.Errorf("failed to lookup downloads: %s", err)
|
||||||
}
|
}
|
||||||
if len(files) == 0 {
|
if len(files) == 0 {
|
||||||
return fmt.Errorf("no update downloads found")
|
return errors.New("no update downloads found")
|
||||||
} else if len(files) > 1 {
|
} else if len(files) > 1 {
|
||||||
// Shouldn't happen
|
// Shouldn't happen
|
||||||
slog.Warn(fmt.Sprintf("multiple downloads found, using first one %v", files))
|
slog.Warn(fmt.Sprintf("multiple downloads found, using first one %v", files))
|
||||||
|
@ -64,7 +65,7 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// TODO - some details about why it didn't start, or is this a pedantic error case?
|
// TODO - some details about why it didn't start, or is this a pedantic error case?
|
||||||
return fmt.Errorf("installer process did not start")
|
return errors.New("installer process did not start")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO should we linger for a moment and check to make sure it's actually running by checking the pid?
|
// TODO should we linger for a moment and check to make sure it's actually running by checking the pid?
|
||||||
|
|
|
@ -3,11 +3,11 @@
|
||||||
package tray
|
package tray
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/tray/commontray"
|
"github.com/ollama/ollama/app/tray/commontray"
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) {
|
func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) {
|
||||||
return nil, fmt.Errorf("NOT IMPLEMENTED YET")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,9 +11,7 @@ import (
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var quitOnce sync.Once
|
||||||
quitOnce sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
func (t *winTray) Run() {
|
func (t *winTray) Run() {
|
||||||
nativeLoop()
|
nativeLoop()
|
||||||
|
|
|
@ -13,8 +13,9 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/tray/commontray"
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/app/tray/commontray"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32
|
// Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32
|
||||||
|
@ -414,7 +415,7 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
|
||||||
iconFilePath := filepath.Join(os.TempDir(), "ollama_temp_icon_"+dataHash)
|
iconFilePath := filepath.Join(os.TempDir(), "ollama_temp_icon_"+dataHash)
|
||||||
|
|
||||||
if _, err := os.Stat(iconFilePath); os.IsNotExist(err) {
|
if _, err := os.Stat(iconFilePath); os.IsNotExist(err) {
|
||||||
if err := os.WriteFile(iconFilePath, iconBytes, 0644); err != nil {
|
if err := os.WriteFile(iconFilePath, iconBytes, 0o644); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
@ -78,7 +79,7 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||||
parts := bytes.Split(publicKey, []byte(" "))
|
parts := bytes.Split(publicKey, []byte(" "))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", errors.New("malformed public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
signedData, err := privateKey.Sign(rand.Reader, bts)
|
signedData, err := privateKey.Sign(rand.Reader, bts)
|
||||||
|
|
|
@ -1160,7 +1160,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := startApp(cmd.Context(), client); err != nil {
|
if err := startApp(cmd.Context(), client); err != nil {
|
||||||
return fmt.Errorf("could not connect to ollama app, is it running?")
|
return errors.New("could not connect to ollama app, is it running?")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -604,7 +604,7 @@ func getImageData(filePath string) ([]byte, error) {
|
||||||
// Check if the file size exceeds 100MB
|
// Check if the file size exceeds 100MB
|
||||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||||
if info.Size() > maxSize {
|
if info.Size() > maxSize {
|
||||||
return nil, fmt.Errorf("file size exceeds maximum limit (100MB)")
|
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
||||||
}
|
}
|
||||||
|
|
||||||
buf = make([]byte, info.Size())
|
buf = make([]byte, info.Size())
|
||||||
|
|
|
@ -2,7 +2,7 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -20,7 +20,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !strings.Contains(link, "Ollama.app") {
|
if !strings.Contains(link, "Ollama.app") {
|
||||||
return fmt.Errorf("could not find ollama app")
|
return errors.New("could not find ollama app")
|
||||||
}
|
}
|
||||||
path := strings.Split(link, "Ollama.app")
|
path := strings.Split(link, "Ollama.app")
|
||||||
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||||
|
|
|
@ -4,11 +4,11 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startApp(ctx context.Context, client *api.Client) error {
|
func startApp(ctx context.Context, client *api.Client) error {
|
||||||
return fmt.Errorf("could not connect to ollama server, run 'ollama serve' to start it")
|
return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||||
// Finally look in the path
|
// Finally look in the path
|
||||||
appExe, err = exec.LookPath(AppName)
|
appExe, err = exec.LookPath(AppName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not locate ollama app")
|
return errors.New("could not locate ollama app")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,9 +5,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
"github.com/pdevine/tensor"
|
"github.com/pdevine/tensor"
|
||||||
"github.com/pdevine/tensor/native"
|
"github.com/pdevine/tensor/native"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type llama struct {
|
type llama struct {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -14,8 +15,9 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
|
func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
|
||||||
|
@ -99,7 +101,7 @@ func TestConvertFull(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
actual[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
|
expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
|
||||||
|
|
|
@ -10,8 +10,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ZipReader struct {
|
type ZipReader struct {
|
||||||
r *zip.Reader
|
r *zip.Reader
|
||||||
p string
|
p string
|
||||||
|
|
||||||
// limit is the maximum size of a file that can be read directly
|
// limit is the maximum size of a file that can be read directly
|
||||||
// from the zip archive. Files larger than this size will be extracted
|
// from the zip archive. Files larger than this size will be extracted
|
||||||
|
|
|
@ -111,8 +111,9 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, b := range u16s {
|
f32s = make([]float32, len(u16s))
|
||||||
f32s = append(f32s, float16.Frombits(b).Float32())
|
for i := range u16s {
|
||||||
|
f32s[i] = float16.Frombits(u16s[i]).Float32()
|
||||||
}
|
}
|
||||||
|
|
||||||
case "BF16":
|
case "BF16":
|
||||||
|
|
|
@ -3,6 +3,7 @@ package format
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -28,6 +29,6 @@ func HumanNumber(b uint64) string {
|
||||||
case b >= Thousand:
|
case b >= Thousand:
|
||||||
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
|
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%d", b)
|
return strconv.FormatUint(b, 10)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
package gpu
|
package gpu
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -95,5 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
return "", errors.New("no suitable rocm found, falling back to CPU")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gpu
|
package gpu
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -76,7 +77,7 @@ func (hl *HipLib) Release() {
|
||||||
|
|
||||||
func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||||
if hl.dll == 0 {
|
if hl.dll == 0 {
|
||||||
return 0, 0, fmt.Errorf("dll has been unloaded")
|
return 0, 0, errors.New("dll has been unloaded")
|
||||||
}
|
}
|
||||||
var version int
|
var version int
|
||||||
status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
|
status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
|
||||||
|
@ -110,7 +111,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
|
||||||
|
|
||||||
func (hl *HipLib) HipSetDevice(device int) error {
|
func (hl *HipLib) HipSetDevice(device int) error {
|
||||||
if hl.dll == 0 {
|
if hl.dll == 0 {
|
||||||
return fmt.Errorf("dll has been unloaded")
|
return errors.New("dll has been unloaded")
|
||||||
}
|
}
|
||||||
status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
|
status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
|
||||||
if status != hipSuccess {
|
if status != hipSuccess {
|
||||||
|
@ -121,7 +122,7 @@ func (hl *HipLib) HipSetDevice(device int) error {
|
||||||
|
|
||||||
func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
|
func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
|
||||||
if hl.dll == 0 {
|
if hl.dll == 0 {
|
||||||
return nil, fmt.Errorf("dll has been unloaded")
|
return nil, errors.New("dll has been unloaded")
|
||||||
}
|
}
|
||||||
var props hipDevicePropMinimal
|
var props hipDevicePropMinimal
|
||||||
status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
|
status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
|
||||||
|
@ -134,7 +135,7 @@ func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, err
|
||||||
// free, total, err
|
// free, total, err
|
||||||
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
|
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
|
||||||
if hl.dll == 0 {
|
if hl.dll == 0 {
|
||||||
return 0, 0, fmt.Errorf("dll has been unloaded")
|
return 0, 0, errors.New("dll has been unloaded")
|
||||||
}
|
}
|
||||||
var totalMemory uint64
|
var totalMemory uint64
|
||||||
var freeMemory uint64
|
var freeMemory uint64
|
||||||
|
|
|
@ -393,7 +393,7 @@ func AMDValidateLibDir() (string, error) {
|
||||||
|
|
||||||
// If we still haven't found a usable rocm, the user will have to install it on their own
|
// If we still haven't found a usable rocm, the user will have to install it on their own
|
||||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
|
slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
|
||||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
return "", errors.New("no suitable rocm found, falling back to CPU")
|
||||||
}
|
}
|
||||||
|
|
||||||
func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||||
|
|
|
@ -2,7 +2,7 @@ package gpu
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -85,7 +85,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||||
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
||||||
gfx := string(props.GcnArchName[:n])
|
gfx := string(props.GcnArchName[:n])
|
||||||
slog.Debug("hip device", "id", i, "name", name, "gfx", gfx)
|
slog.Debug("hip device", "id", i, "name", name, "gfx", gfx)
|
||||||
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
// slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
||||||
// TODO Why isn't props.iGPU accurate!?
|
// TODO Why isn't props.iGPU accurate!?
|
||||||
if strings.EqualFold(name, iGPUName) {
|
if strings.EqualFold(name, iGPUName) {
|
||||||
slog.Info("unsupported Radeon iGPU detected skipping", "id", i, "name", name, "gfx", gfx)
|
slog.Info("unsupported Radeon iGPU detected skipping", "id", i, "name", name, "gfx", gfx)
|
||||||
|
@ -161,7 +161,7 @@ func AMDValidateLibDir() (string, error) {
|
||||||
|
|
||||||
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
||||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
||||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
return "", errors.New("no suitable rocm found, falling back to CPU")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||||
|
|
|
@ -42,7 +42,7 @@ func PayloadsDir() (string, error) {
|
||||||
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
|
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = os.MkdirAll(tmpDir, 0755)
|
err = os.MkdirAll(tmpDir, 0o755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to generate tmp dir %s: %w", tmpDir, err)
|
return "", fmt.Errorf("failed to generate tmp dir %s: %w", tmpDir, err)
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ func PayloadsDir() (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if _, err := pidFile.Write([]byte(fmt.Sprint(os.Getpid()))); err != nil {
|
if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
19
gpu/gpu.go
19
gpu/gpu.go
|
@ -7,9 +7,9 @@ package gpu
|
||||||
#cgo windows LDFLAGS: -lpthread
|
#cgo windows LDFLAGS: -lpthread
|
||||||
|
|
||||||
#include "gpu_info.h"
|
#include "gpu_info.h"
|
||||||
|
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
@ -70,7 +70,6 @@ var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||||
|
|
||||||
// Note: gpuMutex must already be held
|
// Note: gpuMutex must already be held
|
||||||
func initCudaHandles() *cudaHandles {
|
func initCudaHandles() *cudaHandles {
|
||||||
|
|
||||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||||
|
|
||||||
cHandles := &cudaHandles{}
|
cHandles := &cudaHandles{}
|
||||||
|
@ -211,14 +210,16 @@ func GetGPUInfo() GpuInfoList {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("error looking up system memory", "error", err)
|
slog.Warn("error looking up system memory", "error", err)
|
||||||
}
|
}
|
||||||
cpus = []CPUInfo{CPUInfo{
|
cpus = []CPUInfo{
|
||||||
GpuInfo: GpuInfo{
|
{
|
||||||
memInfo: mem,
|
GpuInfo: GpuInfo{
|
||||||
Library: "cpu",
|
memInfo: mem,
|
||||||
Variant: cpuCapability,
|
Library: "cpu",
|
||||||
ID: "0",
|
Variant: cpuCapability,
|
||||||
|
ID: "0",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
}
|
||||||
|
|
||||||
// Fallback to CPU mode if we're lacking required vector extensions on x86
|
// Fallback to CPU mode if we're lacking required vector extensions on x86
|
||||||
if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
|
if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
|
||||||
|
|
|
@ -8,6 +8,7 @@ package gpu
|
||||||
#include "gpu_info_darwin.h"
|
#include "gpu_info_darwin.h"
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
|
|
@ -67,4 +67,4 @@ void cpu_check_ram(mem_info_t *resp);
|
||||||
#include "gpu_info_oneapi.h"
|
#include "gpu_info_oneapi.h"
|
||||||
|
|
||||||
#endif // __GPU_INFO_H__
|
#endif // __GPU_INFO_H__
|
||||||
#endif // __APPLE__
|
#endif // __APPLE__
|
||||||
|
|
|
@ -43,10 +43,12 @@ var OneapiGlobs = []string{
|
||||||
"/usr/lib*/libze_intel_gpu.so*",
|
"/usr/lib*/libze_intel_gpu.so*",
|
||||||
}
|
}
|
||||||
|
|
||||||
var CudartMgmtName = "libcudart.so*"
|
var (
|
||||||
var NvcudaMgmtName = "libcuda.so*"
|
CudartMgmtName = "libcudart.so*"
|
||||||
var NvmlMgmtName = "" // not currently wired on linux
|
NvcudaMgmtName = "libcuda.so*"
|
||||||
var OneapiMgmtName = "libze_intel_gpu.so"
|
NvmlMgmtName = "" // not currently wired on linux
|
||||||
|
OneapiMgmtName = "libze_intel_gpu.so"
|
||||||
|
)
|
||||||
|
|
||||||
func GetCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
var mem memInfo
|
var mem memInfo
|
||||||
|
|
|
@ -40,10 +40,12 @@ var OneapiGlobs = []string{
|
||||||
"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
|
"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
|
||||||
}
|
}
|
||||||
|
|
||||||
var CudartMgmtName = "cudart64_*.dll"
|
var (
|
||||||
var NvcudaMgmtName = "nvcuda.dll"
|
CudartMgmtName = "cudart64_*.dll"
|
||||||
var NvmlMgmtName = "nvml.dll"
|
NvcudaMgmtName = "nvcuda.dll"
|
||||||
var OneapiMgmtName = "ze_intel_gpu64.dll"
|
NvmlMgmtName = "nvml.dll"
|
||||||
|
OneapiMgmtName = "ze_intel_gpu64.dll"
|
||||||
|
)
|
||||||
|
|
||||||
func GetCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
|
memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
|
||||||
|
|
|
@ -162,7 +162,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
||||||
fn := func(resp api.ProgressResponse) error {
|
fn := func(resp api.ProgressResponse) error {
|
||||||
// fmt.Print(".")
|
// fmt.Print(".")
|
||||||
if !stallTimer.Reset(stallDuration) {
|
if !stallTimer.Reset(stallDuration) {
|
||||||
return fmt.Errorf("stall was detected, aborting status reporting")
|
return errors.New("stall was detected, aborting status reporting")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -180,7 +180,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
return fmt.Errorf("download stalled")
|
return errors.New("download stalled")
|
||||||
case <-done:
|
case <-done:
|
||||||
return pullError
|
return pullError
|
||||||
}
|
}
|
||||||
|
@ -243,7 +243,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||||
// fmt.Print(".")
|
// fmt.Print(".")
|
||||||
buf.Write([]byte(response.Response))
|
buf.Write([]byte(response.Response))
|
||||||
if !stallTimer.Reset(streamTimeout) {
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
return errors.New("stall was detected while streaming response, aborting")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,9 @@ package llm
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
// #include "llama.h"
|
// #include "llama.h"
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,7 +34,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
||||||
params.ftype = ftype.Value()
|
params.ftype = ftype.Value()
|
||||||
|
|
||||||
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
|
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
|
||||||
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
|
return errors.New("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -6,10 +6,11 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/gpu"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEstimateGPULayers(t *testing.T) {
|
func TestEstimateGPULayers(t *testing.T) {
|
||||||
|
|
|
@ -184,15 +184,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
|
|
||||||
params := []string{
|
params := []string{
|
||||||
"--model", model,
|
"--model", model,
|
||||||
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
|
"--ctx-size", strconv.Itoa(opts.NumCtx),
|
||||||
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
"--batch-size", strconv.Itoa(opts.NumBatch),
|
||||||
"--embedding",
|
"--embedding",
|
||||||
}
|
}
|
||||||
|
|
||||||
params = append(params, "--log-disable")
|
params = append(params, "--log-disable")
|
||||||
|
|
||||||
if opts.NumGPU >= 0 {
|
if opts.NumGPU >= 0 {
|
||||||
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
|
params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
|
||||||
}
|
}
|
||||||
|
|
||||||
if envconfig.Debug() {
|
if envconfig.Debug() {
|
||||||
|
@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.MainGPU > 0 {
|
if opts.MainGPU > 0 {
|
||||||
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
|
params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(adapters) > 0 {
|
if len(adapters) > 0 {
|
||||||
|
@ -214,7 +214,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.NumThread > 0 {
|
if opts.NumThread > 0 {
|
||||||
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
|
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !opts.F16KV {
|
if !opts.F16KV {
|
||||||
|
@ -260,7 +260,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
params = append(params, "--numa")
|
params = append(params, "--numa")
|
||||||
}
|
}
|
||||||
|
|
||||||
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
params = append(params, "--parallel", strconv.Itoa(numParallel))
|
||||||
|
|
||||||
if estimate.TensorSplit != "" {
|
if estimate.TensorSplit != "" {
|
||||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||||
|
@ -425,7 +425,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
||||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
||||||
}
|
}
|
||||||
s.done <- fmt.Errorf(s.status.LastErrMsg)
|
s.done <- errors.New(s.status.LastErrMsg)
|
||||||
} else {
|
} else {
|
||||||
s.done <- err
|
s.done <- err
|
||||||
}
|
}
|
||||||
|
|
3
main.go
3
main.go
|
@ -3,8 +3,9 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/ollama/ollama/cmd"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
@ -14,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
@ -367,24 +369,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
data, ok := c.(map[string]any)
|
data, ok := c.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, errors.New("invalid message format")
|
||||||
}
|
}
|
||||||
switch data["type"] {
|
switch data["type"] {
|
||||||
case "text":
|
case "text":
|
||||||
text, ok := data["text"].(string)
|
text, ok := data["text"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, errors.New("invalid message format")
|
||||||
}
|
}
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
||||||
case "image_url":
|
case "image_url":
|
||||||
var url string
|
var url string
|
||||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
if url, ok = urlMap["url"].(string); !ok {
|
if url, ok = urlMap["url"].(string); !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, errors.New("invalid message format")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if url, ok = data["image_url"].(string); !ok {
|
if url, ok = data["image_url"].(string); !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, errors.New("invalid message format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -400,17 +402,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !valid {
|
if !valid {
|
||||||
return nil, fmt.Errorf("invalid image input")
|
return nil, errors.New("invalid image input")
|
||||||
}
|
}
|
||||||
|
|
||||||
img, err := base64.StdEncoding.DecodeString(url)
|
img, err := base64.StdEncoding.DecodeString(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, errors.New("invalid message format")
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, errors.New("invalid message format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -423,7 +425,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
toolCalls[i].Function.Name = tc.Function.Name
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid tool call arguments")
|
return nil, errors.New("invalid tool call arguments")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
||||||
|
@ -737,14 +739,12 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||||
var embedResponse api.EmbedResponse
|
var embedResponse api.EmbedResponse
|
||||||
err := json.Unmarshal(data, &embedResponse)
|
err := json.Unmarshal(data, &embedResponse)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,13 +12,16 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const prefix = `data:image/jpeg;base64,`
|
const (
|
||||||
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
prefix = `data:image/jpeg;base64,`
|
||||||
const imageURL = prefix + image
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
|
imageURL = prefix + image
|
||||||
|
)
|
||||||
|
|
||||||
func prepareRequest(req *http.Request, body any) {
|
func prepareRequest(req *http.Request, body any) {
|
||||||
bodyBytes, _ := json.Marshal(body)
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
|
@ -82,7 +82,7 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileFrom(t *testing.T) {
|
func TestParseFileFrom(t *testing.T) {
|
||||||
var cases = []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
expected []Command
|
expected []Command
|
||||||
err error
|
err error
|
||||||
|
@ -185,7 +185,7 @@ BADCOMMAND param1 value1
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileMessages(t *testing.T) {
|
func TestParseFileMessages(t *testing.T) {
|
||||||
var cases = []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
expected []Command
|
expected []Command
|
||||||
err error
|
err error
|
||||||
|
@ -276,7 +276,7 @@ MESSAGE system`,
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileQuoted(t *testing.T) {
|
func TestParseFileQuoted(t *testing.T) {
|
||||||
var cases = []struct {
|
cases := []struct {
|
||||||
multiline string
|
multiline string
|
||||||
expected []Command
|
expected []Command
|
||||||
err error
|
err error
|
||||||
|
@ -430,7 +430,7 @@ TEMPLATE """
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileParameters(t *testing.T) {
|
func TestParseFileParameters(t *testing.T) {
|
||||||
var cases = map[string]struct {
|
cases := map[string]struct {
|
||||||
name, value string
|
name, value string
|
||||||
}{
|
}{
|
||||||
"numa true": {"numa", "true"},
|
"numa true": {"numa", "true"},
|
||||||
|
@ -491,7 +491,7 @@ func TestParseFileParameters(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileComments(t *testing.T) {
|
func TestParseFileComments(t *testing.T) {
|
||||||
var cases = []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
expected []Command
|
expected []Command
|
||||||
}{
|
}{
|
||||||
|
@ -516,7 +516,7 @@ FROM foo
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileFormatParseFile(t *testing.T) {
|
func TestParseFileFormatParseFile(t *testing.T) {
|
||||||
var cases = []string{
|
cases := []string{
|
||||||
`
|
`
|
||||||
FROM foo
|
FROM foo
|
||||||
ADAPTER adapter1
|
ADAPTER adapter1
|
||||||
|
|
|
@ -6,8 +6,9 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/format"
|
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bar struct {
|
type Bar struct {
|
||||||
|
|
|
@ -13,7 +13,7 @@ type Buffer struct {
|
||||||
DisplayPos int
|
DisplayPos int
|
||||||
Pos int
|
Pos int
|
||||||
Buf *arraylist.List
|
Buf *arraylist.List
|
||||||
//LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
|
// LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
|
||||||
LineHasSpace *arraylist.List
|
LineHasSpace *arraylist.List
|
||||||
Prompt *Prompt
|
Prompt *Prompt
|
||||||
LineWidth int
|
LineWidth int
|
||||||
|
@ -56,7 +56,7 @@ func (b *Buffer) GetLineSpacing(line int) bool {
|
||||||
|
|
||||||
func (b *Buffer) MoveLeft() {
|
func (b *Buffer) MoveLeft() {
|
||||||
if b.Pos > 0 {
|
if b.Pos > 0 {
|
||||||
//asserts that we retrieve a rune
|
// asserts that we retrieve a rune
|
||||||
if e, ok := b.Buf.Get(b.Pos - 1); ok {
|
if e, ok := b.Buf.Get(b.Pos - 1); ok {
|
||||||
if r, ok := e.(rune); ok {
|
if r, ok := e.(rune); ok {
|
||||||
rLength := runewidth.RuneWidth(r)
|
rLength := runewidth.RuneWidth(r)
|
||||||
|
|
|
@ -4,9 +4,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var ErrInterrupt = errors.New("Interrupt")
|
||||||
ErrInterrupt = errors.New("Interrupt")
|
|
||||||
)
|
|
||||||
|
|
||||||
type InterruptError struct {
|
type InterruptError struct {
|
||||||
Line []rune
|
Line []rune
|
||||||
|
|
|
@ -7,8 +7,10 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tcgets = 0x5401
|
const (
|
||||||
const tcsets = 0x5402
|
tcgets = 0x5401
|
||||||
|
tcsets = 0x5402
|
||||||
|
)
|
||||||
|
|
||||||
func getTermios(fd uintptr) (*Termios, error) {
|
func getTermios(fd uintptr) (*Termios, error) {
|
||||||
termios := new(Termios)
|
termios := new(Termios)
|
||||||
|
|
|
@ -28,8 +28,10 @@ import (
|
||||||
|
|
||||||
const maxRetries = 6
|
const maxRetries = 6
|
||||||
|
|
||||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
var (
|
||||||
var errPartStalled = errors.New("part stalled")
|
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||||
|
errPartStalled = errors.New("part stalled")
|
||||||
|
)
|
||||||
|
|
||||||
var blobDownloadManager sync.Map
|
var blobDownloadManager sync.Map
|
||||||
|
|
||||||
|
|
|
@ -828,7 +828,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||||
|
|
||||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
return fmt.Errorf("insecure protocol http")
|
return errors.New("insecure protocol http")
|
||||||
}
|
}
|
||||||
|
|
||||||
manifest, _, err := GetManifest(mp)
|
manifest, _, err := GetManifest(mp)
|
||||||
|
@ -895,7 +895,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
}
|
}
|
||||||
|
|
||||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
return fmt.Errorf("insecure protocol http")
|
return errors.New("insecure protocol http")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||||
|
@ -1010,7 +1010,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||||
}
|
}
|
||||||
|
|
||||||
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
|
var errUnauthorized = errors.New("unauthorized: access denied")
|
||||||
|
|
||||||
// getTokenSubject returns the subject of a JWT token, it does not validate the token
|
// getTokenSubject returns the subject of a JWT token, it does not validate the token
|
||||||
func getTokenSubject(token string) string {
|
func getTokenSubject(token string) string {
|
||||||
|
|
|
@ -2,9 +2,9 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
@ -88,7 +88,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
|
|
||||||
m.filepath = p
|
m.filepath = p
|
||||||
m.fi = fi
|
m.fi = fi
|
||||||
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
|
||||||
|
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ func createManifest(t *testing.T, path, name string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
p := filepath.Join(path, "manifests", name)
|
p := filepath.Join(path, "manifests", name)
|
||||||
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
|
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
|
@ -55,8 +55,10 @@ func init() {
|
||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errRequired = errors.New("is required")
|
var (
|
||||||
var errBadTemplate = errors.New("template error")
|
errRequired = errors.New("is required")
|
||||||
|
errBadTemplate = errors.New("template error")
|
||||||
|
)
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
|
@ -369,7 +371,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
input[i] = s
|
input[i] = s
|
||||||
}
|
}
|
||||||
embeddings, err := r.Embed(c.Request.Context(), input)
|
embeddings, err := r.Embed(c.Request.Context(), input)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("embedding generation failed", "error", err)
|
slog.Error("embedding generation failed", "error", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
|
@ -430,7 +431,6 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
|
@ -556,7 +556,7 @@ func checkNameExists(name model.Name) error {
|
||||||
|
|
||||||
for n := range names {
|
for n := range names {
|
||||||
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
|
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
|
||||||
return fmt.Errorf("a model with that name already exists")
|
return errors.New("a model with that name already exists")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -729,7 +729,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
|
|
||||||
n := model.ParseName(req.Model)
|
n := model.ParseName(req.Model)
|
||||||
if !n.IsValid() {
|
if !n.IsValid() {
|
||||||
return nil, fmt.Errorf("invalid model name")
|
return nil, errors.New("invalid model name")
|
||||||
}
|
}
|
||||||
|
|
||||||
manifest, err := ParseNamedManifest(n)
|
manifest, err := ParseNamedManifest(n)
|
||||||
|
@ -993,7 +993,7 @@ func allowedHost(host string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var tlds = []string{
|
tlds := []string{
|
||||||
"localhost",
|
"localhost",
|
||||||
"local",
|
"local",
|
||||||
"internal",
|
"internal",
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
@ -489,7 +490,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
@ -501,7 +502,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
@ -513,7 +514,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -333,7 +333,6 @@ func Test_Routes(t *testing.T) {
|
||||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
||||||
}
|
}
|
||||||
_, err := io.ReadAll(resp.Body)
|
_, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ var defaultModelsPerGPU = 3
|
||||||
// we'll back off down to 1 to try to get it to fit
|
// we'll back off down to 1 to try to get it to fit
|
||||||
var defaultParallel = 4
|
var defaultParallel = 4
|
||||||
|
|
||||||
var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
|
var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
|
||||||
|
|
||||||
func InitScheduler(ctx context.Context) *Scheduler {
|
func InitScheduler(ctx context.Context) *Scheduler {
|
||||||
maxQueue := envconfig.MaxQueue()
|
maxQueue := envconfig.MaxQueue()
|
||||||
|
|
|
@ -3,23 +3,25 @@ package server
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func TestMain(m *testing.M) {
|
||||||
os.Setenv("OLLAMA_DEBUG", "1")
|
os.Setenv("OLLAMA_DEBUG", "1")
|
||||||
lifecycle.InitLogging()
|
lifecycle.InitLogging()
|
||||||
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInitScheduler(t *testing.T) {
|
func TestInitScheduler(t *testing.T) {
|
||||||
|
@ -46,7 +48,7 @@ func TestLoad(t *testing.T) {
|
||||||
}
|
}
|
||||||
// Fail to load model first
|
// Fail to load model first
|
||||||
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return nil, fmt.Errorf("something failed to load model blah")
|
return nil, errors.New("something failed to load model blah")
|
||||||
}
|
}
|
||||||
gpus := gpu.GpuInfoList{}
|
gpus := gpu.GpuInfoList{}
|
||||||
s.load(req, ggml, gpus, 0)
|
s.load(req, ggml, gpus, 0)
|
||||||
|
@ -75,7 +77,7 @@ func TestLoad(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
req.model.ModelPath = "dummy_model_path"
|
req.model.ModelPath = "dummy_model_path"
|
||||||
server.waitResp = fmt.Errorf("wait failure")
|
server.waitResp = errors.New("wait failure")
|
||||||
s.load(req, ggml, gpus, 0)
|
s.load(req, ggml, gpus, 0)
|
||||||
select {
|
select {
|
||||||
case err := <-req.errCh:
|
case err := <-req.errCh:
|
||||||
|
@ -600,7 +602,7 @@ func TestNeedsReload(t *testing.T) {
|
||||||
resp = runner.needsReload(ctx, req)
|
resp = runner.needsReload(ctx, req)
|
||||||
require.True(t, resp)
|
require.True(t, resp)
|
||||||
req.opts.NumBatch = runner.Options.NumBatch
|
req.opts.NumBatch = runner.Options.NumBatch
|
||||||
llm.pingResp = fmt.Errorf("foo")
|
llm.pingResp = errors.New("foo")
|
||||||
resp = runner.needsReload(ctx, req)
|
resp = runner.needsReload(ctx, req)
|
||||||
require.True(t, resp)
|
require.True(t, resp)
|
||||||
llm.pingResp = nil
|
llm.pingResp = nil
|
||||||
|
@ -724,15 +726,19 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
|
||||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
return s.completionResp
|
return s.completionResp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
|
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
|
||||||
return s.embedResp, s.embedRespErr
|
return s.embedResp, s.embedRespErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
return s.tokenizeResp, s.tokenizeRespErr
|
return s.tokenizeResp, s.tokenizeRespErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||||
return s.detokenizeResp, s.detonekizeRespErr
|
return s.detokenizeResp, s.detonekizeRespErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockLlm) Close() error {
|
func (s *mockLlm) Close() error {
|
||||||
s.closeCalled = true
|
s.closeCalled = true
|
||||||
return s.closeResp
|
return s.closeResp
|
||||||
|
|
|
@ -12,13 +12,15 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var blobUploadManager sync.Map
|
var blobUploadManager sync.Map
|
||||||
|
@ -212,7 +214,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
|
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
headers.Set("Content-Type", "application/octet-stream")
|
headers.Set("Content-Type", "application/octet-stream")
|
||||||
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
|
headers.Set("Content-Length", strconv.FormatInt(part.Size, 10))
|
||||||
|
|
||||||
if method == http.MethodPatch {
|
if method == http.MethodPatch {
|
||||||
headers.Set("X-Redirect-Uploads", "1")
|
headers.Set("X-Redirect-Uploads", "1")
|
||||||
|
|
|
@ -15,8 +15,9 @@ import (
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/agnivade/levenshtein"
|
"github.com/agnivade/levenshtein"
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed index.json
|
//go:embed index.json
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,8 +6,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const UnknownOllamaKeyErrMsg = "unknown ollama key"
|
const (
|
||||||
const InvalidModelNameErrMsg = "invalid model name"
|
UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||||
|
InvalidModelNameErrMsg = "invalid model name"
|
||||||
|
)
|
||||||
|
|
||||||
// TODO: This should have a structured response from the API
|
// TODO: This should have a structured response from the API
|
||||||
type UnknownOllamaKey struct {
|
type UnknownOllamaKey struct {
|
||||||
|
|
|
@ -258,7 +258,7 @@ func (n Name) IsValid() bool {
|
||||||
// IsFullyQualified returns true if all parts of the name are present and
|
// IsFullyQualified returns true if all parts of the name are present and
|
||||||
// valid without the digest.
|
// valid without the digest.
|
||||||
func (n Name) IsFullyQualified() bool {
|
func (n Name) IsFullyQualified() bool {
|
||||||
var parts = []string{
|
parts := []string{
|
||||||
n.Host,
|
n.Host,
|
||||||
n.Namespace,
|
n.Namespace,
|
||||||
n.Model,
|
n.Model,
|
||||||
|
|
Loading…
Reference in a new issue