ollama/api/types_test.go
Daniel Hiltgen 171796791f Adjust mmap logic for cuda windows for faster model load
On Windows, recent llama.cpp changes make mmap slower in most
cases, so default to off.  This also implements a tri-state for
use_mmap so we can detect the difference between a user provided
value of true/false, or unspecified.
2024-06-17 16:54:30 -07:00

143 lines
2.7 KiB
Go

package api
import (
"encoding/json"
"math"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestKeepAliveParsingFromJSON decodes a ChatRequest from small JSON payloads
// and checks the resulting KeepAlive value.
//
// The table covers numeric and string forms of "keep_alive": a fractional
// number of seconds is expected to come back as the whole-second value, and
// every negative form (integer, float, duration string) is expected to come
// back as math.MaxInt64.
func TestKeepAliveParsingFromJSON(t *testing.T) {
	// keep builds the pointer form that ChatRequest.KeepAlive is compared against.
	keep := func(d time.Duration) *Duration {
		return &Duration{Duration: d}
	}

	cases := []struct {
		label   string
		payload string
		want    *Duration
	}{
		{label: "Positive Integer", payload: `{ "keep_alive": 42 }`, want: keep(42 * time.Second)},
		{label: "Positive Float", payload: `{ "keep_alive": 42.5 }`, want: keep(42 * time.Second)},
		{label: "Positive Integer String", payload: `{ "keep_alive": "42m" }`, want: keep(42 * time.Minute)},
		{label: "Negative Integer", payload: `{ "keep_alive": -1 }`, want: keep(math.MaxInt64)},
		{label: "Negative Float", payload: `{ "keep_alive": -3.14 }`, want: keep(math.MaxInt64)},
		{label: "Negative Integer String", payload: `{ "keep_alive": "-1m" }`, want: keep(math.MaxInt64)},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			var got ChatRequest
			// Decoding must succeed before the field is worth inspecting.
			require.NoError(t, json.Unmarshal([]byte(tc.payload), &got))
			assert.Equal(t, tc.want, got.KeepAlive)
		})
	}
}
// TestDurationMarshalUnmarshal round-trips Duration values through
// json.Marshal and json.Unmarshal and compares the decoded value against the
// expectation for each case.
//
// Per the table, a negative input is expected to decode as math.MaxInt64,
// while zero, positive and maximum inputs are expected to survive unchanged.
func TestDurationMarshalUnmarshal(t *testing.T) {
	const maxDuration = time.Duration(math.MaxInt64)

	cases := []struct {
		label string
		in    time.Duration
		want  time.Duration
	}{
		{label: "negative duration", in: time.Duration(-1), want: maxDuration},
		{label: "positive duration", in: 42 * time.Second, want: 42 * time.Second},
		{label: "another positive duration", in: 42 * time.Minute, want: 42 * time.Minute},
		{label: "zero duration", in: time.Duration(0), want: time.Duration(0)},
		{label: "max duration", in: maxDuration, want: maxDuration},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			encoded, err := json.Marshal(Duration{Duration: tc.in})
			require.NoError(t, err)

			var decoded Duration
			require.NoError(t, json.Unmarshal(encoded, &decoded))

			// The failure message carries the wire form so a mismatch can be
			// traced to either the encode or the decode side.
			assert.Equal(t, tc.want, decoded.Duration, "input %v, marshalled %v, got %v", tc.in, string(encoded), decoded.Duration)
		})
	}
}
// TestUseMmapParsingFromJSON decodes a JSON object into a generic map, applies
// it to the value returned by DefaultOptions via FromMap, and checks the
// resulting UseMMap tri-state.
//
// An object without a "use_mmap" key is expected to leave the field at
// TriStateUndefined; explicit true and false are expected to yield
// TriStateTrue and TriStateFalse respectively.
func TestUseMmapParsingFromJSON(t *testing.T) {
	cases := []struct {
		label   string
		payload string
		want    TriState
	}{
		{label: "Undefined", payload: `{ }`, want: TriStateUndefined},
		{label: "True", payload: `{ "use_mmap": true }`, want: TriStateTrue},
		{label: "False", payload: `{ "use_mmap": false }`, want: TriStateFalse},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			var fields map[string]interface{}
			require.NoError(t, json.Unmarshal([]byte(tc.payload), &fields))

			// Start from the defaults so only keys present in the payload
			// are overridden.
			opts := DefaultOptions()
			require.NoError(t, opts.FromMap(fields))

			assert.Equal(t, tc.want, opts.UseMMap)
		})
	}
}