This commit is contained in:
Michael Yang 2024-05-21 22:21:04 -07:00
parent c895a7d13f
commit e40145a39d
31 changed files with 127 additions and 136 deletions

View file

@ -12,8 +12,14 @@ linters:
# FIXME: for some reason this errors on windows # FIXME: for some reason this errors on windows
# - gofmt # - gofmt
# - goimports # - goimports
- intrange
- misspell - misspell
- nilerr - nilerr
- nolintlint - nolintlint
- nosprintfhostport - nosprintfhostport
- testifylint
- unconvert
- unused - unused
- usestdlibvars
- wastedassign
- whitespace

View file

@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
}, },
{ {
"positive duration", "positive duration",
time.Duration(42 * time.Second), 42 * time.Second,
time.Duration(42 * time.Second), 42 * time.Second,
}, },
{ {
"another positive duration", "another positive duration",
time.Duration(42 * time.Minute), 42 * time.Minute,
time.Duration(42 * time.Minute), 42 * time.Minute,
}, },
{ {
"zero duration", "zero duration",

View file

@ -69,7 +69,6 @@ func init() {
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err)) slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
} }
} }
} else if runtime.GOOS == "darwin" { } else if runtime.GOOS == "darwin" {
// TODO // TODO
AppName += ".app" AppName += ".app"

View file

@ -15,7 +15,7 @@ import (
) )
func getCLIFullPath(command string) string { func getCLIFullPath(command string) string {
cmdPath := "" var cmdPath string
appExe, err := os.Executable() appExe, err := os.Executable()
if err == nil { if err == nil {
cmdPath = filepath.Join(filepath.Dir(appExe), command) cmdPath = filepath.Join(filepath.Dir(appExe), command)
@ -65,7 +65,6 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
if err != nil { if err != nil {
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err) return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
} }
if err := os.MkdirAll(logDir, 0o755); err != nil { if err := os.MkdirAll(logDir, 0o755); err != nil {

View file

@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == 204 { if resp.StatusCode == http.StatusNoContent {
slog.Debug("check update response 204 (current version is up to date)") slog.Debug("check update response 204 (current version is up to date)")
return false, updateResp return false, updateResp
} }
@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
slog.Warn(fmt.Sprintf("failed to read body response: %s", err)) slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
} }
if resp.StatusCode != 200 { if resp.StatusCode != http.StatusOK {
slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body))) slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
return false, updateResp return false, updateResp
} }
@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
if err != nil { if err != nil {
return fmt.Errorf("error checking update: %w", err) return fmt.Errorf("error checking update: %w", err)
} }
if resp.StatusCode != 200 { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
} }
resp.Body.Close() resp.Body.Close()

View file

@ -29,7 +29,6 @@ func GetID() string {
initStore() initStore()
} }
return store.ID return store.ID
} }
func GetFirstTimeRun() bool { func GetFirstTimeRun() bool {

View file

@ -746,7 +746,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
if wordWrap && termWidth >= 10 { if wordWrap && termWidth >= 10 {
for _, ch := range content { for _, ch := range content {
if state.lineLength+1 > termWidth-5 { if state.lineLength+1 > termWidth-5 {
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 { if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch) fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = "" state.wordBuffer = ""
@ -1044,7 +1043,6 @@ func waitForServer(ctx context.Context, client *api.Client) error {
} }
} }
} }
} }
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {

View file

@ -6,6 +6,7 @@ import (
"text/template" "text/template"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -85,11 +86,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
tmpl, err := template.New("").Parse(expectedModelfile) tmpl, err := template.New("").Parse(expectedModelfile)
assert.Nil(t, err) require.NoError(t, err)
var buf bytes.Buffer var buf bytes.Buffer
err = tmpl.Execute(&buf, opts) err = tmpl.Execute(&buf, opts)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, buf.String(), mf) assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark" opts.ParentModel = "horseshark"
@ -107,10 +108,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
tmpl, err = template.New("").Parse(expectedModelfile) tmpl, err = template.New("").Parse(expectedModelfile)
assert.Nil(t, err) require.NoError(t, err)
var parentBuf bytes.Buffer var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts) err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, parentBuf.String(), mf) assert.Equal(t, parentBuf.String(), mf)
} }

View file

@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
if params.VocabSize > len(v.Tokens) { if params.VocabSize > len(v.Tokens) {
missingTokens := params.VocabSize - len(v.Tokens) missingTokens := params.VocabSize - len(v.Tokens)
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens)) slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
for cnt := 0; cnt < missingTokens; cnt++ { for cnt := range missingTokens {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1)) v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
v.Scores = append(v.Scores, -1) v.Scores = append(v.Scores, -1)
v.Types = append(v.Types, tokenTypeUserDefined) v.Types = append(v.Types, tokenTypeUserDefined)

View file

@ -104,7 +104,6 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
} }
return tensors, nil return tensors, nil
} }
func getAltParams(dirpath string) (*Params, error) { func getAltParams(dirpath string) (*Params, error) {

View file

@ -5,7 +5,6 @@ import (
) )
func TestHumanNumber(t *testing.T) { func TestHumanNumber(t *testing.T) {
type testCase struct { type testCase struct {
input uint64 input uint64
expected string expected string

View file

@ -80,7 +80,7 @@ func cleanupTmpDirs() {
if err == nil { if err == nil {
pid, err := strconv.Atoi(string(raw)) pid, err := strconv.Atoi(string(raw))
if err == nil { if err == nil {
if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
// Another running ollama, ignore this tmpdir // Another running ollama, ignore this tmpdir
continue continue
} }

View file

@ -5,11 +5,12 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestBasicGetGPUInfo(t *testing.T) { func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo() info := GetGPUInfo()
assert.Greater(t, len(info), 0) assert.NotEmpty(t, len(info))
assert.Contains(t, "cuda rocm cpu metal", info[0].Library) assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
if info[0].Library != "cpu" { if info[0].Library != "cpu" {
assert.Greater(t, info[0].TotalMemory, uint64(0)) assert.Greater(t, info[0].TotalMemory, uint64(0))
@ -19,7 +20,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
func TestCPUMemInfo(t *testing.T) { func TestCPUMemInfo(t *testing.T) {
info, err := GetCPUMem() info, err := GetCPUMem()
assert.NoError(t, err) require.NoError(t, err)
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case "darwin":
t.Skip("CPU memory not populated on darwin") t.Skip("CPU memory not populated on darwin")

View file

@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err return err
} }
dims := 0 var dims int
for cnt := 0; cnt < len(tensor.Shape); cnt++ { for cnt := range len(tensor.Shape) {
if tensor.Shape[cnt] > 0 { if tensor.Shape[cnt] > 0 {
dims++ dims++
} }
@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err return err
} }
for i := 0; i < dims; i++ { for i := range dims {
if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil { if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
return err return err
} }
} }

View file

@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
} }
var layerCount int var layerCount int
for i := 0; i < int(ggml.KV().BlockCount()); i++ { for i := range int(ggml.KV().BlockCount()) {
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
memoryLayer := blk.size() memoryLayer := blk.size()

View file

@ -85,7 +85,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var systemMemory uint64 var systemMemory uint64
gpuCount := len(gpus) gpuCount := len(gpus)
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 { if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
cpuRunner = serverForCpu() cpuRunner = serverForCpu()
@ -233,7 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ { for i := range len(servers) {
dir := availableServers[servers[i]] dir := availableServers[servers[i]]
if dir == "" { if dir == "" {
// Shouldn't happen // Shouldn't happen
@ -316,7 +315,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
s.cmd.Stdout = os.Stdout s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status s.cmd.Stderr = s.status
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv() visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path and visible devices variable with our adjusted version // Update or add the path and visible devices variable with our adjusted version

View file

@ -245,7 +245,6 @@ func (w *writer) writeResponse(data []byte) (int, error) {
d, err := json.Marshal(toChunk(w.id, chatResponse)) d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil { if err != nil {
return 0, err return 0, err
} }
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

View file

@ -10,6 +10,7 @@ import (
"unicode/utf16" "unicode/utf16"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestParseFileFile(t *testing.T) { func TestParseFileFile(t *testing.T) {
@ -25,7 +26,7 @@ TEMPLATE template1
reader := strings.NewReader(input) reader := strings.NewReader(input)
modelfile, err := ParseFile(reader) modelfile, err := ParseFile(reader)
assert.NoError(t, err) require.NoError(t, err)
expectedCommands := []Command{ expectedCommands := []Command{
{Name: "model", Args: "model1"}, {Name: "model", Args: "model1"},
@ -88,7 +89,7 @@ func TestParseFileFrom(t *testing.T) {
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err) require.ErrorIs(t, err, c.err)
if modelfile != nil { if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
} }
@ -105,7 +106,7 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := ParseFile(reader) _, err := ParseFile(reader)
assert.ErrorIs(t, err, io.ErrUnexpectedEOF) require.ErrorIs(t, err, io.ErrUnexpectedEOF)
} }
func TestParseFileBadCommand(t *testing.T) { func TestParseFileBadCommand(t *testing.T) {
@ -114,8 +115,7 @@ FROM foo
BADCOMMAND param1 value1 BADCOMMAND param1 value1
` `
_, err := ParseFile(strings.NewReader(input)) _, err := ParseFile(strings.NewReader(input))
assert.ErrorIs(t, err, errInvalidCommand) require.ErrorIs(t, err, errInvalidCommand)
} }
func TestParseFileMessages(t *testing.T) { func TestParseFileMessages(t *testing.T) {
@ -201,7 +201,7 @@ MESSAGE system`,
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err) require.ErrorIs(t, err, c.err)
if modelfile != nil { if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
} }
@ -355,7 +355,7 @@ TEMPLATE """
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.multiline)) modelfile, err := ParseFile(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err) require.ErrorIs(t, err, c.err)
if modelfile != nil { if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
} }
@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) {
fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", k) fmt.Fprintln(&b, "PARAMETER", k)
modelfile, err := ParseFile(&b) modelfile, err := ParseFile(&b)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []Command{ assert.Equal(t, []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
@ -442,7 +442,7 @@ FROM foo
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
}) })
} }
@ -501,15 +501,14 @@ SYSTEM ""
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c)) modelfile, err := ParseFile(strings.NewReader(c))
assert.NoError(t, err) require.NoError(t, err)
modelfile2, err := ParseFile(strings.NewReader(modelfile.String())) modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, modelfile, modelfile2) assert.Equal(t, modelfile, modelfile2)
}) })
} }
} }
func TestParseFileUTF16ParseFile(t *testing.T) { func TestParseFileUTF16ParseFile(t *testing.T) {
@ -522,10 +521,10 @@ SYSTEM You are a utf16 file.
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...)) utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err := binary.Write(buf, binary.LittleEndian, utf16File) err := binary.Write(buf, binary.LittleEndian, utf16File)
assert.NoError(t, err) require.NoError(t, err)
actual, err := ParseFile(buf) actual, err := ParseFile(buf)
assert.NoError(t, err) require.NoError(t, err)
expected := []Command{ expected := []Command{
{Name: "model", Args: "bob"}, {Name: "model", Args: "bob"},
@ -539,9 +538,9 @@ SYSTEM You are a utf16 file.
// simulate a utf16 be file // simulate a utf16 be file
buf = new(bytes.Buffer) buf = new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, utf16File) err = binary.Write(buf, binary.BigEndian, utf16File)
assert.NoError(t, err) require.NoError(t, err)
actual, err = ParseFile(buf) actual, err = ParseFile(buf)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual.Commands) assert.Equal(t, expected, actual.Commands)
} }

View file

@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool {
stopped := p.stop() stopped := p.stop()
if stopped { if stopped {
// clear all progress lines // clear all progress lines
for i := 0; i < p.pos; i++ { for i := range p.pos {
if i > 0 { if i > 0 {
fmt.Fprint(p.w, "\033[A") fmt.Fprint(p.w, "\033[A")
} }
@ -85,7 +85,7 @@ func (p *Progress) render() {
defer fmt.Fprint(p.w, "\033[?25h") defer fmt.Fprint(p.w, "\033[?25h")
// clear already rendered progress lines // clear already rendered progress lines
for i := 0; i < p.pos; i++ { for i := range p.pos {
if i > 0 { if i > 0 {
fmt.Fprint(p.w, "\033[A") fmt.Fprint(p.w, "\033[A")
} }

View file

@ -154,7 +154,7 @@ func (b *Buffer) MoveToStart() {
if b.Pos > 0 { if b.Pos > 0 {
currLine := b.DisplayPos / b.LineWidth currLine := b.DisplayPos / b.LineWidth
if currLine > 0 { if currLine > 0 {
for cnt := 0; cnt < currLine; cnt++ { for range currLine {
fmt.Print(CursorUp) fmt.Print(CursorUp)
} }
} }
@ -169,7 +169,7 @@ func (b *Buffer) MoveToEnd() {
currLine := b.DisplayPos / b.LineWidth currLine := b.DisplayPos / b.LineWidth
totalLines := b.DisplaySize() / b.LineWidth totalLines := b.DisplaySize() / b.LineWidth
if currLine < totalLines { if currLine < totalLines {
for cnt := 0; cnt < totalLines-currLine; cnt++ { for range totalLines - currLine {
fmt.Print(CursorDown) fmt.Print(CursorDown)
} }
remainder := b.DisplaySize() % b.LineWidth remainder := b.DisplaySize() % b.LineWidth
@ -451,7 +451,7 @@ func (b *Buffer) DeleteBefore() {
func (b *Buffer) DeleteRemaining() { func (b *Buffer) DeleteRemaining() {
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() { if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
charsToDel := b.Buf.Size() - b.Pos charsToDel := b.Buf.Size() - b.Pos
for cnt := 0; cnt < charsToDel; cnt++ { for range charsToDel {
b.Delete() b.Delete()
} }
} }
@ -495,7 +495,7 @@ func (b *Buffer) ClearScreen() {
if currPos > 0 { if currPos > 0 {
targetLine := currPos / b.LineWidth targetLine := currPos / b.LineWidth
if targetLine > 0 { if targetLine > 0 {
for cnt := 0; cnt < targetLine; cnt++ { for range targetLine {
fmt.Print(CursorDown) fmt.Print(CursorDown)
} }
} }

View file

@ -91,7 +91,7 @@ func (h *History) Add(l []rune) {
func (h *History) Compact() { func (h *History) Compact() {
s := h.Buf.Size() s := h.Buf.Size()
if s > h.Limit { if s > h.Limit {
for cnt := 0; cnt < s-h.Limit; cnt++ { for range s - h.Limit {
h.Buf.Remove(0) h.Buf.Remove(0)
} }
} }
@ -139,7 +139,7 @@ func (h *History) Save() error {
defer f.Close() defer f.Close()
buf := bufio.NewWriter(f) buf := bufio.NewWriter(f)
for cnt := 0; cnt < h.Size(); cnt++ { for cnt := range h.Size() {
v, _ := h.Buf.Get(cnt) v, _ := h.Buf.Get(cnt)
line, _ := v.([]rune) line, _ := v.([]rune)
if _, err := buf.WriteString(string(line) + "\n"); err != nil { if _, err := buf.WriteString(string(line) + "\n"); err != nil {

View file

@ -63,7 +63,7 @@ func New(prompt Prompt) (*Instance, error) {
func (i *Instance) Readline() (string, error) { func (i *Instance) Readline() (string, error) {
if !i.Terminal.rawmode { if !i.Terminal.rawmode {
fd := int(syscall.Stdin) fd := syscall.Stdin
termios, err := SetRawMode(fd) termios, err := SetRawMode(fd)
if err != nil { if err != nil {
return "", err return "", err
@ -80,7 +80,7 @@ func (i *Instance) Readline() (string, error) {
fmt.Print(prompt) fmt.Print(prompt)
defer func() { defer func() {
fd := int(syscall.Stdin) fd := syscall.Stdin
//nolint:errcheck //nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios) UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false i.Terminal.rawmode = false
@ -136,7 +136,7 @@ func (i *Instance) Readline() (string, error) {
buf.MoveRight() buf.MoveRight()
case CharBracketedPaste: case CharBracketedPaste:
var code string var code string
for cnt := 0; cnt < 3; cnt++ { for range 3 {
r, err = i.Terminal.Read() r, err = i.Terminal.Read()
if err != nil { if err != nil {
return "", io.EOF return "", io.EOF
@ -198,7 +198,7 @@ func (i *Instance) Readline() (string, error) {
buf.Remove() buf.Remove()
case CharTab: case CharTab:
// todo: convert back to real tabs // todo: convert back to real tabs
for cnt := 0; cnt < 8; cnt++ { for range 8 {
buf.Add(' ') buf.Add(' ')
} }
case CharDelete: case CharDelete:
@ -216,7 +216,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlW: case CharCtrlW:
buf.DeleteWord() buf.DeleteWord()
case CharCtrlZ: case CharCtrlZ:
fd := int(syscall.Stdin) fd := syscall.Stdin
return handleCharCtrlZ(fd, i.Terminal.termios) return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ: case CharEnter, CharCtrlJ:
output := buf.String() output := buf.String()
@ -248,7 +248,7 @@ func (i *Instance) HistoryDisable() {
} }
func NewTerminal() (*Terminal, error) { func NewTerminal() (*Terminal, error) {
fd := int(syscall.Stdin) fd := syscall.Stdin
termios, err := SetRawMode(fd) termios, err := SetRawMode(fd)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -987,7 +987,7 @@ func getTokenSubject(token string) string {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ { for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil { if err != nil {
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {

View file

@ -72,7 +72,6 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
default: default:
layers = append(layers, &layerWithGGML{layer, nil}) layers = append(layers, &layerWithGGML{layer, nil})
} }
} }
return layers, nil return layers, nil

View file

@ -6,12 +6,13 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestGetBlobsPath(t *testing.T) { func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist // GetBlobsPath expects an actual directory to exist
dir, err := os.MkdirTemp("", "ollama-test") dir, err := os.MkdirTemp("", "ollama-test")
assert.Nil(t, err) require.NoError(t, err)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
tests := []struct { tests := []struct {
@ -63,7 +64,7 @@ func TestGetBlobsPath(t *testing.T) {
got, err := GetBlobsPath(tc.digest) got, err := GetBlobsPath(tc.digest)
assert.ErrorIs(t, tc.err, err, tc.name) require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name) assert.Equal(t, tc.expected, got, tc.name)
}) })
} }

View file

@ -77,7 +77,6 @@ func isSupportedImageType(image []byte) bool {
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
@ -942,7 +941,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
} }
if allowedHost(host) { if allowedHost(host) {
if c.Request.Method == "OPTIONS" { if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent) c.AbortWithStatus(http.StatusNoContent)
return return
} }
@ -1306,7 +1305,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch) defer close(ch)
fn := func(r llm.CompletionResponse) { fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),

View file

@ -15,6 +15,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
@ -25,20 +26,20 @@ func createTestFile(t *testing.T, name string) string {
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), name) f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err) assert.NoError(t, err)
defer f.Close() defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err) assert.NoError(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3)) err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err) assert.NoError(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0)) err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err) assert.NoError(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0)) err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err) assert.NoError(t, err)
return f.Name() return f.Name()
} }
@ -57,12 +58,12 @@ func Test_Routes(t *testing.T) {
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile, err := parser.ParseFile(r) modelfile, err := parser.ParseFile(r)
assert.Nil(t, err) require.NoError(t, err)
fn := func(resp api.ProgressResponse) { fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
err = CreateModel(context.TODO(), name, "", "", modelfile, fn) err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
assert.Nil(t, err) require.NoError(t, err)
} }
testCases := []testCase{ testCases := []testCase{
@ -74,9 +75,9 @@ func Test_Routes(t *testing.T) {
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8") assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body)) assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
}, },
}, },
@ -86,17 +87,17 @@ func Test_Routes(t *testing.T) {
Path: "/api/tags", Path: "/api/tags",
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8") assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
assert.Nil(t, err) require.NoError(t, err)
var modelList api.ListResponse var modelList api.ListResponse
err = json.Unmarshal(body, &modelList) err = json.Unmarshal(body, &modelList)
assert.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, modelList.Models) assert.NotNil(t, modelList.Models)
assert.Equal(t, 0, len(modelList.Models)) assert.Empty(t, len(modelList.Models))
}, },
}, },
{ {
@ -108,16 +109,16 @@ func Test_Routes(t *testing.T) {
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8") assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
assert.Nil(t, err) require.NoError(t, err)
var modelList api.ListResponse var modelList api.ListResponse
err = json.Unmarshal(body, &modelList) err = json.Unmarshal(body, &modelList)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 1, len(modelList.Models)) assert.Len(t, modelList.Models, 1)
assert.Equal(t, modelList.Models[0].Name, "test-model:latest") assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
}, },
}, },
{ {
@ -134,7 +135,7 @@ func Test_Routes(t *testing.T) {
Stream: &stream, Stream: &stream,
} }
jsonData, err := json.Marshal(createReq) jsonData, err := json.Marshal(createReq)
assert.Nil(t, err) require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData)) req.Body = io.NopCloser(bytes.NewReader(jsonData))
}, },
@ -142,11 +143,11 @@ func Test_Routes(t *testing.T) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType) assert.Equal(t, "application/json", contentType)
_, err := io.ReadAll(resp.Body) _, err := io.ReadAll(resp.Body)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, resp.StatusCode, 200) assert.Equal(t, 200, resp.StatusCode)
model, err := GetModel("t-bone") model, err := GetModel("t-bone")
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName) assert.Equal(t, "t-bone:latest", model.ShortName)
}, },
}, },
@ -161,13 +162,13 @@ func Test_Routes(t *testing.T) {
Destination: "beefsteak", Destination: "beefsteak",
} }
jsonData, err := json.Marshal(copyReq) jsonData, err := json.Marshal(copyReq)
assert.Nil(t, err) require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData)) req.Body = io.NopCloser(bytes.NewReader(jsonData))
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak") model, err := GetModel("beefsteak")
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName) assert.Equal(t, "beefsteak:latest", model.ShortName)
}, },
}, },
@ -179,18 +180,18 @@ func Test_Routes(t *testing.T) {
createTestModel(t, "show-model") createTestModel(t, "show-model")
showReq := api.ShowRequest{Model: "show-model"} showReq := api.ShowRequest{Model: "show-model"}
jsonData, err := json.Marshal(showReq) jsonData, err := json.Marshal(showReq)
assert.Nil(t, err) require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData)) req.Body = io.NopCloser(bytes.NewReader(jsonData))
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8") assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
assert.Nil(t, err) require.NoError(t, err)
var showResp api.ShowResponse var showResp api.ShowResponse
err = json.Unmarshal(body, &showResp) err = json.Unmarshal(body, &showResp)
assert.Nil(t, err) require.NoError(t, err)
var params []string var params []string
paramsSplit := strings.Split(showResp.Parameters, "\n") paramsSplit := strings.Split(showResp.Parameters, "\n")
@ -221,14 +222,14 @@ func Test_Routes(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err) require.NoError(t, err)
if tc.Setup != nil { if tc.Setup != nil {
tc.Setup(t, req) tc.Setup(t, req)
} }
resp, err := httpSrv.Client().Do(req) resp, err := httpSrv.Client().Do(req)
assert.Nil(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
if tc.Expected != nil { if tc.Expected != nil {

View file

@ -370,7 +370,6 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
r.refMu.Lock() r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus)) gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil { if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread // TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus)) estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus { for _, gpu := range r.gpus {
@ -529,7 +528,6 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
} }
}() }()
return finished return finished
} }
type ByDuration []*runnerRef type ByDuration []*runnerRef

View file

@ -12,11 +12,10 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/envconfig"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -53,10 +52,10 @@ func TestLoad(t *testing.T) {
} }
gpus := gpu.GpuInfoList{} gpus := gpu.GpuInfoList{}
s.load(req, ggml, gpus) s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0) require.Empty(t, req.successCh)
require.Len(t, req.errCh, 1) require.Len(t, req.errCh, 1)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 0) require.Empty(t, s.loaded)
s.loadedMu.Unlock() s.loadedMu.Unlock()
err := <-req.errCh err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible") require.Contains(t, err.Error(), "this model may be incompatible")
@ -113,7 +112,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName) f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err) require.NoError(t, err)
defer f.Close() defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian) gguf := llm.NewGGUFV3(binary.LittleEndian)
@ -131,7 +130,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
}, []llm.Tensor{ }, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
}) })
assert.Nil(t, err) require.NoError(t, err)
fname := f.Name() fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname} model := &Model{Name: modelName, ModelPath: fname}
@ -190,8 +189,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario1a.req.successCh: case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario1a.req.errCh, 0) require.Empty(t, scenario1a.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -203,8 +202,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario1b.req.successCh: case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario1b.req.errCh, 0) require.Empty(t, scenario1b.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -221,8 +220,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario2a.req.successCh: case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv) require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario2a.req.errCh, 0) require.Empty(t, scenario2a.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -237,8 +236,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3a.req.successCh: case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv) require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario3a.req.errCh, 0) require.Empty(t, scenario3a.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -253,8 +252,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3b.req.successCh: case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv) require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario3b.req.errCh, 0) require.Empty(t, scenario3b.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -269,8 +268,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3c.req.successCh: case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv) require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario3c.req.errCh, 0) require.Empty(t, scenario3c.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -296,8 +295,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3d.req.successCh: case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv) require.Equal(t, resp.llama, scenario3d.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, scenario3d.req.errCh, 0) require.Empty(t, scenario3d.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -332,7 +331,7 @@ func TestGetRunner(t *testing.T) {
slog.Info("scenario1b") slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1) require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0) require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1) require.Len(t, errCh1b, 1)
err := <-errCh1b err := <-errCh1b
require.Contains(t, err.Error(), "server busy") require.Contains(t, err.Error(), "server busy")
@ -340,8 +339,8 @@ func TestGetRunner(t *testing.T) {
select { select {
case resp := <-successCh1a: case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, errCh1a, 0) require.Empty(t, errCh1a)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@ -355,9 +354,9 @@ func TestGetRunner(t *testing.T) {
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error // Starts in pending channel, then should be quickly processsed to return an error
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
require.Len(t, successCh1c, 0) require.Empty(t, successCh1c)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 0) require.Empty(t, s.loaded)
s.loadedMu.Unlock() s.loadedMu.Unlock()
require.Len(t, errCh1c, 1) require.Len(t, errCh1c, 1)
err = <-errCh1c err = <-errCh1c
@ -386,8 +385,8 @@ func TestPrematureExpired(t *testing.T) {
select { select {
case resp := <-successCh1a: case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0) require.Empty(t, s.pendingReqCh)
require.Len(t, errCh1a, 0) require.Empty(t, errCh1a)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 1) require.Len(t, s.loaded, 1)
s.loadedMu.Unlock() s.loadedMu.Unlock()
@ -401,9 +400,9 @@ func TestPrematureExpired(t *testing.T) {
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1) require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0) require.Empty(t, s.finishedReqCh)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 0) require.Empty(t, s.loaded)
s.loadedMu.Unlock() s.loadedMu.Unlock()
// also shouldn't happen in real life // also shouldn't happen in real life
@ -487,7 +486,6 @@ func TestFindRunnerToUnload(t *testing.T) {
r2.refCount = 1 r2.refCount = 1
resp = s.findRunnerToUnload() resp = s.findRunnerToUnload()
require.Equal(t, r1, resp) require.Equal(t, r1, resp)
} }
func TestNeedsReload(t *testing.T) { func TestNeedsReload(t *testing.T) {

View file

@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
case requestURL := <-b.nextURL: case requestURL := <-b.nextURL:
g.Go(func() error { g.Go(func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := range maxRetries {
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts) err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0") headers.Set("Content-Length", "0")
for try := 0; try < maxRetries; try++ { for try := range maxRetries {
var resp *http.Response var resp *http.Response
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
} }
// retry uploading to the redirect URL // retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ { for try := range maxRetries {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil) err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):

View file

@ -268,7 +268,6 @@ func TestNameIsValidPart(t *testing.T) {
} }
}) })
} }
} }
func TestFilepathAllocs(t *testing.T) { func TestFilepathAllocs(t *testing.T) {
@ -382,7 +381,6 @@ func FuzzName(f *testing.F) {
t.Errorf("String() = %q; want %q", n.String(), s) t.Errorf("String() = %q; want %q", n.String(), s)
} }
} }
}) })
} }