Compare commits
40 commits
87345eda1b
...
1d125ce9b7
Author | SHA1 | Date | |
---|---|---|---|
1d125ce9b7 | |||
|
80ee9b5e47 | ||
|
5534f2cc6a | ||
|
d321297d8a | ||
|
06e5d74e34 | ||
|
5d707e6fd5 | ||
|
283948c83b | ||
|
1475eab95f | ||
|
20090f3172 | ||
|
69a2d4ccff | ||
|
e8b954c646 | ||
|
c57317cbf0 | ||
|
51b2fd299c | ||
|
d0634b1596 | ||
|
43606d6d6a | ||
|
70b1010fa5 | ||
|
84e5721f3a | ||
|
319fb1ce03 | ||
|
b255445557 | ||
|
b23424bb3c | ||
|
5fd6988126 | ||
|
5b82960df8 | ||
|
cc9a252d8c | ||
|
d281a6e603 | ||
|
154f6f45d4 | ||
|
0d41623b52 | ||
|
c279f96371 | ||
|
499e87c9ba | ||
|
cd0853f2d5 | ||
|
d290e87513 | ||
|
97c20ede33 | ||
|
5a83f79afd | ||
|
987dbab0b0 | ||
|
a8388beb94 | ||
|
5afbb60fc4 | ||
|
4cb5d7decc | ||
|
8eac50dd4f | ||
|
4a565cbf94 | ||
|
73e2c8f68f | ||
|
f4408219e9 |
31 changed files with 1827 additions and 442 deletions
|
@ -294,6 +294,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||||
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
||||||
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
||||||
|
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
|
||||||
|
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
|
||||||
|
|
||||||
### Terminal
|
### Terminal
|
||||||
|
|
||||||
|
|
83
api/types.go
83
api/types.go
|
@ -47,6 +47,9 @@ type GenerateRequest struct {
|
||||||
// Prompt is the textual prompt to send to the model.
|
// Prompt is the textual prompt to send to the model.
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
||||||
|
// Suffix is the text that comes after the inserted text.
|
||||||
|
Suffix string `json:"suffix"`
|
||||||
|
|
||||||
// System overrides the model's default system message/prompt.
|
// System overrides the model's default system message/prompt.
|
||||||
System string `json:"system"`
|
System string `json:"system"`
|
||||||
|
|
||||||
|
@ -98,48 +101,29 @@ type ChatRequest struct {
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
// Tools is an optional list of tools the model has access to.
|
// Tools is an optional list of tools the model has access to.
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools `json:"tools,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Tools []Tool
|
||||||
|
|
||||||
|
func (t Tools) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
// Message is a single message in a chat sequence. The message contains the
|
// Message is a single message in a chat sequence. The message contains the
|
||||||
// role ("system", "user", or "assistant"), the content and an optional list
|
// role ("system", "user", or "assistant"), the content and an optional list
|
||||||
// of images.
|
// of images.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content,omitempty"`
|
Content string `json:"content"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments map[string]any `json:"arguments"`
|
|
||||||
} `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
} `json:"properties"`
|
|
||||||
} `json:"parameters"`
|
|
||||||
} `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
type Alias Message
|
type Alias Message
|
||||||
var a Alias
|
var a Alias
|
||||||
|
@ -152,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Function ToolCallFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallFunctionArguments map[string]any
|
||||||
|
|
||||||
|
func (t *ToolCallFunctionArguments) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function ToolFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
} `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ToolFunction) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||||
// similar to [GenerateResponse].
|
// similar to [GenerateResponse].
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
|
@ -404,9 +428,6 @@ type GenerateResponse struct {
|
||||||
// Response is the textual response itself.
|
// Response is the textual response itself.
|
||||||
Response string `json:"response"`
|
Response string `json:"response"`
|
||||||
|
|
||||||
// ToolCalls is the list of tools the model wants to call
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
||||||
|
|
||||||
// Done specifies if the response is complete.
|
// Done specifies if the response is complete.
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
|
|
15
docs/gpu.md
15
docs/gpu.md
|
@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
|
||||||
|
|
||||||
## AMD Radeon
|
## AMD Radeon
|
||||||
Ollama supports the following AMD GPUs:
|
Ollama supports the following AMD GPUs:
|
||||||
|
|
||||||
|
### Linux Support
|
||||||
| Family | Cards and accelerators |
|
| Family | Cards and accelerators |
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
||||||
|
|
||||||
### Overrides
|
### Windows Support
|
||||||
|
With ROCm v6.1, the following GPUs are supported on Windows.
|
||||||
|
|
||||||
|
| Family | Cards and accelerators |
|
||||||
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
||||||
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||||
|
|
||||||
|
|
||||||
|
### Overrides on Linux
|
||||||
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
||||||
some cases you can force the system to try to use a similar LLVM target that is
|
some cases you can force the system to try to use a similar LLVM target that is
|
||||||
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
||||||
|
@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
|
||||||
server. If you have an unsupported AMD GPU you can experiment using the list of
|
server. If you have an unsupported AMD GPU you can experiment using the list of
|
||||||
supported types below.
|
supported types below.
|
||||||
|
|
||||||
At this time, the known supported GPU types are the following LLVM Targets.
|
At this time, the known supported GPU types on linux are the following LLVM Targets.
|
||||||
This table shows some example GPUs that map to these LLVM targets:
|
This table shows some example GPUs that map to these LLVM targets:
|
||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
|
|
|
@ -103,10 +103,6 @@ curl http://localhost:11434/v1/chat/completions \
|
||||||
- [ ] `user`
|
- [ ] `user`
|
||||||
- [ ] `n`
|
- [ ] `n`
|
||||||
|
|
||||||
#### Notes
|
|
||||||
|
|
||||||
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
|
||||||
Before using a model, pull it locally `ollama pull`:
|
Before using a model, pull it locally `ollama pull`:
|
||||||
|
|
|
@ -33,9 +33,10 @@ type HipLib struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHipLib() (*HipLib, error) {
|
func NewHipLib() (*HipLib, error) {
|
||||||
h, err := windows.LoadLibrary("amdhip64.dll")
|
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
|
||||||
|
h, err := windows.LoadLibrary("amdhip64_6.dll")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
|
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
|
||||||
}
|
}
|
||||||
hl := &HipLib{}
|
hl := &HipLib{}
|
||||||
hl.dll = h
|
hl.dll = h
|
||||||
|
|
|
@ -92,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
if !slices.Contains[[]string, string](supported, gfx) {
|
// Strip off Target Features when comparing
|
||||||
|
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
|
||||||
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
|
|
||||||
func TestContextExhaustion(t *testing.T) {
|
func TestContextExhaustion(t *testing.T) {
|
||||||
// Longer needed for small footprint GPUs
|
// Longer needed for small footprint GPUs
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
|
@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
|
||||||
"num_ctx": 128,
|
"num_ctx": 128,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
|
}
|
||||||
|
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ function amdGPUs {
|
||||||
return $env:AMDGPU_TARGETS
|
return $env:AMDGPU_TARGETS
|
||||||
}
|
}
|
||||||
# Current supported rocblas list from ROCm v6.1.2 on windows
|
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||||
|
# https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
|
||||||
$GPU_LIST = @(
|
$GPU_LIST = @(
|
||||||
"gfx906:xnack-"
|
|
||||||
"gfx1030"
|
"gfx1030"
|
||||||
"gfx1100"
|
"gfx1100"
|
||||||
"gfx1101"
|
"gfx1101"
|
||||||
|
|
|
@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{
|
||||||
"tokenizer.ggml.add_bos_token",
|
"tokenizer.ggml.add_bos_token",
|
||||||
"tokenizer.ggml.add_eos_token",
|
"tokenizer.ggml.add_eos_token",
|
||||||
"tokenizer.chat_template",
|
"tokenizer.chat_template",
|
||||||
|
"bert.pooling_type",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
43
llm/patches/10-tekken.diff
Normal file
43
llm/patches/10-tekken.diff
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
diff --git a/include/llama.h b/include/llama.h
|
||||||
|
index bb4b05ba..a92174e0 100644
|
||||||
|
--- a/include/llama.h
|
||||||
|
+++ b/include/llama.h
|
||||||
|
@@ -92,6 +92,7 @@ extern "C" {
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
||||||
|
+ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
||||||
|
};
|
||||||
|
|
||||||
|
// note: these values should be synchronized with ggml_rope
|
||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index 18364976..435b6fe5 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -5429,6 +5429,12 @@ static void llm_load_vocab(
|
||||||
|
} else if (
|
||||||
|
tokenizer_pre == "jais") {
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
|
||||||
|
+ } else if (
|
||||||
|
+ tokenizer_pre == "tekken") {
|
||||||
|
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
|
||||||
|
+ vocab.tokenizer_clean_spaces = false;
|
||||||
|
+ vocab.tokenizer_ignore_merges = true;
|
||||||
|
+ vocab.tokenizer_add_bos = true;
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
@@ -15448,6 +15454,13 @@ struct llm_tokenizer_bpe {
|
||||||
|
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
+ case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
|
||||||
|
+ // original regex from tokenizer.json
|
||||||
|
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
|
+ regex_exprs = {
|
||||||
|
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
+ };
|
||||||
|
+ break;
|
||||||
|
default:
|
||||||
|
// default regex for BPE tokenization pre-processing
|
||||||
|
regex_exprs = {
|
19
llm/patches/11-embd_kv.diff
Normal file
19
llm/patches/11-embd_kv.diff
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index 2b9ace28..e60d3d8d 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -6052,10 +6052,10 @@ static bool llm_load_tensors(
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
- layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||||
|
- layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
- layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
- layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||||
|
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
|
||||||
|
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
||||||
|
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
||||||
|
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
|
||||||
|
|
||||||
|
// optional bias tensors
|
||||||
|
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
|
@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
filteredEnv := []string{}
|
filteredEnv := []string{}
|
||||||
for _, ev := range s.cmd.Env {
|
for _, ev := range s.cmd.Env {
|
||||||
if strings.HasPrefix(ev, "CUDA_") ||
|
if strings.HasPrefix(ev, "CUDA_") ||
|
||||||
|
strings.HasPrefix(ev, "ROCR_") ||
|
||||||
strings.HasPrefix(ev, "ROCM_") ||
|
strings.HasPrefix(ev, "ROCM_") ||
|
||||||
strings.HasPrefix(ev, "HIP_") ||
|
strings.HasPrefix(ev, "HIP_") ||
|
||||||
|
strings.HasPrefix(ev, "GPU_") ||
|
||||||
strings.HasPrefix(ev, "HSA_") ||
|
strings.HasPrefix(ev, "HSA_") ||
|
||||||
strings.HasPrefix(ev, "GGML_") ||
|
strings.HasPrefix(ev, "GGML_") ||
|
||||||
strings.HasPrefix(ev, "PATH=") ||
|
strings.HasPrefix(ev, "PATH=") ||
|
||||||
|
|
174
openai/openai.go
174
openai/openai.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -31,6 +32,7 @@ type ErrorResponse struct {
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content any `json:"content"`
|
Content any `json:"content"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
|
@ -61,6 +63,11 @@ type ResponseFormat struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbedRequest struct {
|
||||||
|
Input any `json:"input"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
type ChatCompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
|
@ -73,6 +80,7 @@ type ChatCompletionRequest struct {
|
||||||
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
||||||
TopP *float64 `json:"top_p"`
|
TopP *float64 `json:"top_p"`
|
||||||
ResponseFormat *ResponseFormat `json:"response_format"`
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||||
|
Tools []api.Tool `json:"tools"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletion struct {
|
type ChatCompletion struct {
|
||||||
|
@ -106,6 +114,7 @@ type CompletionRequest struct {
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Temperature *float32 `json:"temperature"`
|
Temperature *float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
|
Suffix string `json:"suffix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Completion struct {
|
type Completion struct {
|
||||||
|
@ -127,6 +136,15 @@ type CompletionChunk struct {
|
||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
} `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
|
@ -134,11 +152,23 @@ type Model struct {
|
||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Embedding struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
type ListCompletion struct {
|
type ListCompletion struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Data []Model `json:"data"`
|
Data []Model `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddingList struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Embedding `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
func NewError(code int, message string) ErrorResponse {
|
func NewError(code int, message string) ErrorResponse {
|
||||||
var etype string
|
var etype string
|
||||||
switch code {
|
switch code {
|
||||||
|
@ -153,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
|
||||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toolCallId() string {
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b := make([]byte, 8)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||||
|
}
|
||||||
|
return "call_" + strings.ToLower(string(b))
|
||||||
|
}
|
||||||
|
|
||||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
|
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
||||||
|
for i, tc := range r.Message.ToolCalls {
|
||||||
|
toolCalls[i].ID = toolCallId()
|
||||||
|
toolCalls[i].Type = "function"
|
||||||
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
|
||||||
|
args, err := json.Marshal(tc.Function.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("could not marshall function arguments to json", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls[i].Function.Arguments = string(args)
|
||||||
|
}
|
||||||
|
|
||||||
return ChatCompletion{
|
return ChatCompletion{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
|
@ -162,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
SystemFingerprint: "fp_ollama",
|
SystemFingerprint: "fp_ollama",
|
||||||
Choices: []Choice{{
|
Choices: []Choice{{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
|
||||||
FinishReason: func(reason string) *string {
|
FinishReason: func(reason string) *string {
|
||||||
if len(reason) > 0 {
|
if len(reason) > 0 {
|
||||||
return &reason
|
return &reason
|
||||||
|
@ -171,7 +225,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
@ -217,7 +270,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
@ -262,6 +314,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||||
|
if r.Embeddings != nil {
|
||||||
|
var data []Embedding
|
||||||
|
for i, e := range r.Embeddings {
|
||||||
|
data = append(data, Embedding{
|
||||||
|
Object: "embedding",
|
||||||
|
Embedding: e,
|
||||||
|
Index: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return EmbeddingList{
|
||||||
|
Object: "list",
|
||||||
|
Data: data,
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return EmbeddingList{}
|
||||||
|
}
|
||||||
|
|
||||||
func toModel(r api.ShowResponse, m string) Model {
|
func toModel(r api.ShowResponse, m string) Model {
|
||||||
return Model{
|
return Model{
|
||||||
Id: m,
|
Id: m,
|
||||||
|
@ -278,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
case string:
|
case string:
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||||
case []any:
|
case []any:
|
||||||
message := api.Message{Role: msg.Role}
|
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
data, ok := c.(map[string]any)
|
data, ok := c.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -290,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
message.Content = text
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
||||||
case "image_url":
|
case "image_url":
|
||||||
var url string
|
var url string
|
||||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
|
@ -322,15 +394,27 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
message.Images = append(message.Images, img)
|
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messages = append(messages, message)
|
|
||||||
default:
|
default:
|
||||||
|
if msg.ToolCalls == nil {
|
||||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
||||||
|
for i, tc := range msg.ToolCalls {
|
||||||
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid tool call arguments")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
|
@ -387,6 +471,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
Format: format,
|
Format: format,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
|
Tools: r.Tools,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,6 +522,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
Prompt: r.Prompt,
|
Prompt: r.Prompt,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
|
Suffix: r.Suffix,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -465,6 +551,11 @@ type RetrieveWriter struct {
|
||||||
model string
|
model string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbedWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
||||||
var serr api.StatusError
|
var serr api.StatusError
|
||||||
err := json.Unmarshal(data, &serr)
|
err := json.Unmarshal(data, &serr)
|
||||||
|
@ -630,6 +721,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||||
return w.writeResponse(data)
|
return w.writeResponse(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var embedResponse api.EmbedResponse
|
||||||
|
err := json.Unmarshal(data, &embedResponse)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(code, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
func ListMiddleware() gin.HandlerFunc {
|
func ListMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
w := &ListWriter{
|
w := &ListWriter{
|
||||||
|
@ -693,6 +811,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
||||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req EmbedRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Input == "" {
|
||||||
|
req.Input = []string{""}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Input == nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &EmbedWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
model: req.Model,
|
||||||
|
}
|
||||||
|
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
@ -718,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||||
chatReq, err := fromChatRequest(req)
|
chatReq, err := fromChatRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
|
|
|
@ -20,108 +20,59 @@ const prefix = `data:image/jpeg;base64,`
|
||||||
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
const imageURL = prefix + image
|
const imageURL = prefix + image
|
||||||
|
|
||||||
func TestMiddlewareRequests(t *testing.T) {
|
func prepareRequest(req *http.Request, body any) {
|
||||||
type testCase struct {
|
bodyBytes, _ := json.Marshal(body)
|
||||||
Name string
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
Method string
|
req.Header.Set("Content-Type", "application/json")
|
||||||
Path string
|
}
|
||||||
Handler func() gin.HandlerFunc
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, req *http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
var capturedRequest *http.Request
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||||
|
|
||||||
captureRequestMiddleware := func() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
capturedRequest = c.Request
|
err := json.Unmarshal(bodyBytes, capturedRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
||||||
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.ChatRequest
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "chat handler",
|
Name: "chat handler",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
var chatReq api.ChatRequest
|
if resp.Code != http.StatusOK {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Role != "user" {
|
if req.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Content != "Hello" {
|
if req.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
temp := float32(0.8)
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
Temperature: &temp,
|
|
||||||
Stop: []string{"\n", "stop"},
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
|
||||||
var genReq api.GenerateRequest
|
|
||||||
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Prompt != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Options["temperature"] != 1.6 {
|
|
||||||
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
|
||||||
}
|
|
||||||
|
|
||||||
stopTokens, ok := genReq.Options["stop"].([]any)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected stop tokens to be a list")
|
|
||||||
}
|
|
||||||
|
|
||||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
|
||||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "chat handler with image content",
|
Name: "chat handler with image content",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
|
@ -134,58 +85,313 @@ func TestMiddlewareRequests(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
var chatReq api.ChatRequest
|
if resp.Code != http.StatusOK {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Role != "user" {
|
if req.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Content != "Hello" {
|
if req.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
|
if req.Messages[1].Role != "user" {
|
||||||
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
|
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
if !bytes.Equal(req.Messages[1].Images[0], img) {
|
||||||
router := gin.New()
|
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler with tools",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{
|
||||||
|
{Role: "user", Content: "What's the weather like in Paris Today?"},
|
||||||
|
{Role: "assistant", ToolCalls: []ToolCall{{
|
||||||
|
ID: "id",
|
||||||
|
Type: "function",
|
||||||
|
Function: struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != 200 {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
||||||
|
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
||||||
|
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
||||||
|
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{{Role: "user", Content: 2}},
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
endpoint := func(c *gin.Context) {
|
endpoint := func(c *gin.Context) {
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
router = gin.New()
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
||||||
router.Use(captureRequestMiddleware())
|
|
||||||
router.Use(tc.Handler())
|
|
||||||
router.Handle(tc.Method, tc.Path, endpoint)
|
|
||||||
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
|
||||||
|
|
||||||
if tc.Setup != nil {
|
|
||||||
tc.Setup(t, req)
|
tc.Setup(t, req)
|
||||||
}
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest)
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompletionsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.GenerateRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "completions handler",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
temp := float32(0.8)
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: &temp,
|
||||||
|
Stop: []string{"\n", "stop"},
|
||||||
|
Suffix: "suffix",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if req.Prompt != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Options["temperature"] != 1.6 {
|
||||||
|
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
stopTokens, ok := req.Options["stop"].([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected stop tokens to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||||
|
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Suffix != "suffix" {
|
||||||
|
t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: nil,
|
||||||
|
Stop: []int{1, 2},
|
||||||
|
Suffix: "suffix",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
||||||
|
|
||||||
|
tc.Setup(t, req)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.EmbedRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "embed handler single input",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Input: "Hello",
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if req.Input != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", req.Input)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model != "test-model" {
|
||||||
|
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "embed handler batch input",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Input: []string{"Hello", "World"},
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
input, ok := req.Input.([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected input to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if input[0].(string) != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", input[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if input[1].(string) != "World" {
|
||||||
|
t.Fatalf("expected 'World', got %s", input[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model != "test-model" {
|
||||||
|
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "embed handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid input") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
||||||
|
|
||||||
|
tc.Setup(t, req)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -203,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
|
||||||
Name: "completions handler error forwarding",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
TestPath: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
||||||
},
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Name: "list handler",
|
Name: "list handler",
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
|
@ -249,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
|
|
||||||
var listResp ListCompletion
|
var listResp ListCompletion
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -314,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
|
||||||
tc.Expected(t, resp)
|
tc.Expected(t, resp)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,13 +34,19 @@ import (
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errCapabilityCompletion = errors.New("completion")
|
var (
|
||||||
|
errCapabilities = errors.New("does not support")
|
||||||
|
errCapabilityCompletion = errors.New("completion")
|
||||||
|
errCapabilityTools = errors.New("tools")
|
||||||
|
errCapabilityInsert = errors.New("insert")
|
||||||
|
)
|
||||||
|
|
||||||
type Capability string
|
type Capability string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CapabilityCompletion = Capability("completion")
|
CapabilityCompletion = Capability("completion")
|
||||||
CapabilityTools = Capability("tools")
|
CapabilityTools = Capability("tools")
|
||||||
|
CapabilityInsert = Capability("insert")
|
||||||
)
|
)
|
||||||
|
|
||||||
type registryOptions struct {
|
type registryOptions struct {
|
||||||
|
@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||||
}
|
}
|
||||||
case CapabilityTools:
|
case CapabilityTools:
|
||||||
if !slices.Contains(m.Template.Vars(), "tools") {
|
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||||
errs = append(errs, errors.New("tools"))
|
errs = append(errs, errCapabilityTools)
|
||||||
|
}
|
||||||
|
case CapabilityInsert:
|
||||||
|
vars := m.Template.Vars()
|
||||||
|
if !slices.Contains(vars, "suffix") {
|
||||||
|
errs = append(errs, errCapabilityInsert)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
slog.Error("unknown capability", "capability", cap)
|
slog.Error("unknown capability", "capability", cap)
|
||||||
|
@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := errors.Join(errs...); err != nil {
|
if err := errors.Join(errs...); err != nil {
|
||||||
return fmt.Errorf("does not support %w", errors.Join(errs...))
|
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -481,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
layers = append(layers, baseLayer.Layer)
|
layers = append(layers, baseLayer.Layer)
|
||||||
}
|
}
|
||||||
case "license", "template", "system":
|
case "license", "template", "system":
|
||||||
|
if c.Name == "template" {
|
||||||
|
if _, err := template.Parse(c.Args); err != nil {
|
||||||
|
return fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.Name != "license" {
|
if c.Name != "license" {
|
||||||
// replace
|
// replace
|
||||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
|
@ -312,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := tmpl.Execute(&b, map[string][]map[string]any{
|
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||||
"ToolCalls": {
|
"ToolCalls": {
|
||||||
{
|
{
|
||||||
"Function": map[string]any{
|
Function: api.ToolCallFunction{
|
||||||
"Name": "@@name@@",
|
Name: "@@name@@",
|
||||||
"Arguments": "@@arguments@@",
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"@@argument@@": 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -325,57 +326,48 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
var kv map[string]string
|
var kv map[string]any
|
||||||
// execute the subtree with placeholders to identify the keys
|
// execute the subtree with placeholders to identify the keys
|
||||||
if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
|
// trim any commands that might exist in the template
|
||||||
|
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// find the keys that correspond to the name and arguments fields
|
// find the keys that correspond to the name and arguments fields
|
||||||
var name, arguments string
|
var name, arguments string
|
||||||
for k, v := range kv {
|
for k, v := range kv {
|
||||||
switch v {
|
switch v.(type) {
|
||||||
case "@@name@@":
|
case string:
|
||||||
name = k
|
name = k
|
||||||
case "@@arguments@@":
|
case map[string]any:
|
||||||
arguments = k
|
arguments = k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sm []map[string]any
|
var objs []map[string]any
|
||||||
decoder := json.NewDecoder(strings.NewReader(s))
|
for offset := 0; offset < len(s); {
|
||||||
for {
|
var obj map[string]any
|
||||||
// incrementally decode the JSON into a list of JSON objects
|
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
|
||||||
// skipping over any invalid tokens
|
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
if err := decoder.Decode(&sm); err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
break
|
||||||
}
|
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
||||||
|
// skip over any syntax errors
|
||||||
if errors.As(err, new(*json.SyntaxError)) {
|
offset += int(syntax.Offset)
|
||||||
r := decoder.Buffered()
|
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
|
||||||
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
|
// skip over any unmarshalable types
|
||||||
break
|
offset += int(unmarshalType.Offset)
|
||||||
}
|
} else if err != nil {
|
||||||
|
slog.Error("parseToolCalls", "error", err)
|
||||||
decoder = json.NewDecoder(r)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
|
} else {
|
||||||
|
offset += int(decoder.InputOffset())
|
||||||
|
objs = append(objs, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
// break as soon as a valid object is decoded
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
for _, kv := range sm {
|
for _, kv := range objs {
|
||||||
call := api.ToolCall{
|
var call api.ToolCall
|
||||||
ID: uuid.New().String(),
|
|
||||||
Type: "function",
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range kv {
|
for k, v := range kv {
|
||||||
switch k {
|
switch k {
|
||||||
case name:
|
case name:
|
||||||
|
@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
toolCalls = append(toolCalls, call)
|
toolCalls = append(toolCalls, call)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
return toolCalls, len(toolCalls) > 0
|
||||||
return toolCalls, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments map[string]any `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -136,11 +131,16 @@ func TestExecuteWithTools(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
model string
|
model string
|
||||||
output string
|
output string
|
||||||
|
ok bool
|
||||||
}{
|
}{
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
|
||||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`},
|
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||||
|
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
{"command-r-plus", "Action: ```json" + `
|
{"command-r-plus", "Action: ```json" + `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
@ -158,8 +158,14 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
` + "```"},
|
` + "```", true},
|
||||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
|
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"llama3-groq-tool-use", `<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
||||||
|
</tool_call>`, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
var tools []api.Tool
|
var tools []api.Tool
|
||||||
|
@ -174,20 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
|
||||||
|
|
||||||
calls := []api.ToolCall{
|
calls := []api.ToolCall{
|
||||||
{
|
{
|
||||||
Type: "function",
|
Function: api.ToolCallFunction{
|
||||||
Function: function{
|
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Type: "function",
|
Function: api.ToolCallFunction{
|
||||||
Function: function{
|
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
"location": "Toronto, Canada",
|
"location": "Toronto, Canada",
|
||||||
},
|
},
|
||||||
|
@ -216,18 +220,15 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
|
||||||
t.Run("parse", func(t *testing.T) {
|
t.Run("parse", func(t *testing.T) {
|
||||||
m := &Model{Template: tmpl}
|
m := &Model{Template: tmpl}
|
||||||
actual, ok := m.parseToolCalls(tt.output)
|
actual, ok := m.parseToolCalls(tt.output)
|
||||||
if !ok {
|
if ok != tt.ok {
|
||||||
t.Fatal("failed to parse tool calls")
|
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
||||||
}
|
|
||||||
|
|
||||||
for i := range actual {
|
|
||||||
// ID is randomly generated so clear it for comparison
|
|
||||||
actual[i].ID = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tt.ok {
|
||||||
if diff := cmp.Diff(actual, calls); diff != "" {
|
if diff := cmp.Diff(actual, calls); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package server
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
@ -11,14 +10,6 @@ import (
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func tokenize(_ context.Context, s string) (tokens []int, err error) {
|
|
||||||
for range strings.Fields(s) {
|
|
||||||
tokens = append(tokens, len(tokens))
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatPrompt(t *testing.T) {
|
func TestChatPrompt(t *testing.T) {
|
||||||
type expect struct {
|
type expect struct {
|
||||||
prompt string
|
prompt string
|
||||||
|
@ -192,15 +183,11 @@ func TestChatPrompt(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||||
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
|
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.prompt != prompt {
|
|
||||||
t.Errorf("expected %q, got %q", tt.prompt, prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,6 +56,7 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var errRequired = errors.New("is required")
|
var errRequired = errors.New("is required")
|
||||||
|
var errBadTemplate = errors.New("template error")
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
|
@ -122,6 +123,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
|
if req.Suffix != "" {
|
||||||
|
caps = append(caps, CapabilityInsert)
|
||||||
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
if errors.Is(err, errCapabilityCompletion) {
|
if errors.Is(err, errCapabilityCompletion) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||||
|
@ -150,19 +155,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
prompt := req.Prompt
|
prompt := req.Prompt
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
var msgs []api.Message
|
|
||||||
if req.System != "" {
|
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
|
||||||
} else if m.System != "" {
|
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, i := range images {
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
|
||||||
|
|
||||||
tmpl := m.Template
|
tmpl := m.Template
|
||||||
if req.Template != "" {
|
if req.Template != "" {
|
||||||
tmpl, err = template.Parse(req.Template)
|
tmpl, err = template.Parse(req.Template)
|
||||||
|
@ -183,7 +175,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
b.WriteString(s)
|
b.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
|
var values template.Values
|
||||||
|
if req.Suffix != "" {
|
||||||
|
values.Prompt = prompt
|
||||||
|
values.Suffix = req.Suffix
|
||||||
|
} else {
|
||||||
|
var msgs []api.Message
|
||||||
|
if req.System != "" {
|
||||||
|
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
||||||
|
} else if m.System != "" {
|
||||||
|
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, i := range images {
|
||||||
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||||
|
}
|
||||||
|
|
||||||
|
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpl.Execute(&b, values); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -265,11 +276,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Response = sb.String()
|
r.Response = sb.String()
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
|
||||||
r.ToolCalls = toolCalls
|
|
||||||
r.Response = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, r)
|
c.JSON(http.StatusOK, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -604,6 +610,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||||
|
|
||||||
quantization := cmp.Or(r.Quantize, r.Quantization)
|
quantization := cmp.Or(r.Quantize, r.Quantization)
|
||||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
||||||
|
if errors.Is(err, errBadTemplate) {
|
||||||
|
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||||
|
}
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -1064,6 +1073,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||||
// Compatibility endpoints
|
// Compatibility endpoints
|
||||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||||
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
||||||
|
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
|
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
|
||||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
|
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
|
||||||
|
|
||||||
|
@ -1190,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case gin.H:
|
case gin.H:
|
||||||
|
status, ok := r["status"].(int)
|
||||||
|
if !ok {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
if errorMsg, ok := r["error"].(string); ok {
|
if errorMsg, ok := r["error"].(string); ok {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
c.JSON(status, gin.H{"error": errorMsg})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
|
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -1284,7 +1298,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
if req.Tools != nil {
|
if len(req.Tools) > 0 {
|
||||||
caps = append(caps, CapabilityTools)
|
caps = append(caps, CapabilityTools)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1310,7 +1324,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Messages[0].Role != "system" {
|
if req.Messages[0].Role != "system" && m.System != "" {
|
||||||
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1379,10 +1393,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Message.Content = sb.String()
|
resp.Message.Content = sb.String()
|
||||||
|
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
resp.Message.ToolCalls = toolCalls
|
resp.Message.ToolCalls = toolCalls
|
||||||
resp.Message.Content = ""
|
resp.Message.Content = ""
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
return
|
return
|
||||||
|
@ -1393,7 +1410,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, errRequired):
|
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
c.JSON(499, gin.H{"error": "request canceled"})
|
c.JSON(499, gin.H{"error": "request canceled"})
|
||||||
|
|
|
@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateFromBin(t *testing.T) {
|
func TestCreateFromBin(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateFromModel(t *testing.T) {
|
func TestCreateFromModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateRemovesLayers(t *testing.T) {
|
func TestCreateRemovesLayers(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateUnsetsSystem(t *testing.T) {
|
func TestCreateUnsetsSystem(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateMergeParameters(t *testing.T) {
|
func TestCreateMergeParameters(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateReplacesMessages(t *testing.T) {
|
func TestCreateReplacesMessages(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateTemplateSystem(t *testing.T) {
|
func TestCreateTemplateSystem(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -477,9 +491,47 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||||
if string(system) != "Say bye!" {
|
if string(system) != "Say bye!" {
|
||||||
t.Errorf("expected \"Say bye!\", actual %s", system)
|
t.Errorf("expected \"Say bye!\", actual %s", system)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("incomplete template", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("template with unclosed if", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("template with undefined function", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -526,6 +578,8 @@ func TestCreateLicenses(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateDetectTemplate(t *testing.T) {
|
func TestCreateDetectTemplate(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
|
|
@ -8,12 +8,15 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDelete(t *testing.T) {
|
func TestDelete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteDuplicateLayers(t *testing.T) {
|
func TestDeleteDuplicateLayers(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
var s Server
|
var s Server
|
||||||
|
|
714
server/routes_generate_test.go
Normal file
714
server/routes_generate_test.go
Normal file
|
@ -0,0 +1,714 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockRunner struct {
|
||||||
|
llm.LlamaServer
|
||||||
|
|
||||||
|
// CompletionRequest is only valid until the next call to Completion
|
||||||
|
llm.CompletionRequest
|
||||||
|
llm.CompletionResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||||
|
m.CompletionRequest = r
|
||||||
|
fn(m.CompletionResponse)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||||
|
for range strings.Fields(s) {
|
||||||
|
tokens = append(tokens, len(tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||||
|
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
|
return mock, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateChat(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
// add small delay to simulate loading
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
|
TEMPLATE """
|
||||||
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||||
|
`, createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("missing body", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, nil)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities chat", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(0),
|
||||||
|
}, []llm.Tensor{})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual api.ChatResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != "test" {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done true, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "load" {
|
||||||
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var actual api.ChatResponse
|
||||||
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != model {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "stop" {
|
||||||
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(actual.Message, api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: content,
|
||||||
|
}); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalCount == 0 {
|
||||||
|
t.Errorf("expected prompt eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalDuration == 0 {
|
||||||
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalCount == 0 {
|
||||||
|
t.Errorf("expected eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalDuration == 0 {
|
||||||
|
t.Errorf("expected eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.LoadDuration == 0 {
|
||||||
|
t.Errorf("expected load duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.TotalDuration == 0 {
|
||||||
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Hi!"
|
||||||
|
t.Run("messages", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("messages with model system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||||
|
t.Run("messages with system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You can perform magic tricks."},
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("messages with interleaved system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
{Role: "assistant", Content: "I can help you with that."},
|
||||||
|
{Role: "system", Content: "You can perform magic tricks."},
|
||||||
|
{Role: "user", Content: "Help me write tests."},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
// add small delay to simulate loading
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
|
TEMPLATE """
|
||||||
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||||
|
`, createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("missing body", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, nil)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities generate", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(0),
|
||||||
|
}, []llm.Tensor{})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual api.GenerateResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != "test" {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done true, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "load" {
|
||||||
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var actual api.GenerateResponse
|
||||||
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != model {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "stop" {
|
||||||
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Response != content {
|
||||||
|
t.Errorf("expected response %s, got %s", content, actual.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Context == nil {
|
||||||
|
t.Errorf("expected context not nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalCount == 0 {
|
||||||
|
t.Errorf("expected prompt eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalDuration == 0 {
|
||||||
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalCount == 0 {
|
||||||
|
t.Errorf("expected eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalDuration == 0 {
|
||||||
|
t.Errorf("expected eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.LoadDuration == 0 {
|
||||||
|
t.Errorf("expected load duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.TotalDuration == 0 {
|
||||||
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Hi!"
|
||||||
|
t.Run("prompt", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with model system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||||
|
t.Run("prompt with system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
System: "You can perform magic tricks.",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt with template", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Help me write tests.",
|
||||||
|
System: "You can perform magic tricks.",
|
||||||
|
Template: `{{- if .System }}{{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Modelfile: `FROM test
|
||||||
|
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}"""`,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt without suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Help me write tests.",
|
||||||
|
Raw: true,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -7,11 +7,14 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestList(t *testing.T) {
|
func TestList(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
|
||||||
require.Len(t, s.expiredCh, 1)
|
require.Len(t, s.expiredCh, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type bundle struct {
|
type reqBundle struct {
|
||||||
ctx context.Context //nolint:containedctx
|
ctx context.Context //nolint:containedctx
|
||||||
ctxDone func()
|
ctxDone func()
|
||||||
srv *mockLlm
|
srv *mockLlm
|
||||||
|
@ -102,13 +102,13 @@ type bundle struct {
|
||||||
ggml *llm.GGML
|
ggml *llm.GGML
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return scenario.srv, nil
|
return scenario.srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
|
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
|
||||||
scenario := &bundle{}
|
b := &reqBundle{}
|
||||||
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
|
b.ctx, b.ctxDone = context.WithCancel(ctx)
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||||
|
@ -135,124 +135,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||||
|
|
||||||
fname := f.Name()
|
fname := f.Name()
|
||||||
model := &Model{Name: modelName, ModelPath: fname}
|
model := &Model{Name: modelName, ModelPath: fname}
|
||||||
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
scenario.req = &LlmRequest{
|
if duration == nil {
|
||||||
ctx: scenario.ctx,
|
duration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||||
|
}
|
||||||
|
b.req = &LlmRequest{
|
||||||
|
ctx: b.ctx,
|
||||||
model: model,
|
model: model,
|
||||||
opts: api.DefaultOptions(),
|
opts: api.DefaultOptions(),
|
||||||
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
|
sessionDuration: duration,
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
}
|
}
|
||||||
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||||
return scenario
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequests(t *testing.T) {
|
func getGpuFn() gpu.GpuInfoList {
|
||||||
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
// Same model, same request
|
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
|
||||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
|
||||||
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
|
||||||
scenario1b.req.model = scenario1a.req.model
|
|
||||||
scenario1b.ggml = scenario1a.ggml
|
|
||||||
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
|
|
||||||
// simple reload of same model
|
|
||||||
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
|
||||||
tmpModel := *scenario1a.req.model
|
|
||||||
scenario2a.req.model = &tmpModel
|
|
||||||
scenario2a.ggml = scenario1a.ggml
|
|
||||||
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
|
||||||
|
|
||||||
// Multiple loaded models
|
|
||||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
|
||||||
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
|
||||||
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
|
||||||
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
|
||||||
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
|
||||||
|
|
||||||
s := InitScheduler(ctx)
|
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
g.TotalMemory = 24 * format.GigaByte
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
g.FreeMemory = 12 * format.GigaByte
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
s.getCpuFn = func() gpu.GpuInfoList {
|
|
||||||
|
func getCpuFn() gpu.GpuInfoList {
|
||||||
g := gpu.GpuInfo{Library: "cpu"}
|
g := gpu.GpuInfo{Library: "cpu"}
|
||||||
g.TotalMemory = 32 * format.GigaByte
|
g.TotalMemory = 32 * format.GigaByte
|
||||||
g.FreeMemory = 26 * format.GigaByte
|
g.FreeMemory = 26 * format.GigaByte
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
s.newServerFn = scenario1a.newServer
|
|
||||||
slog.Info("scenario1a")
|
func TestRequestsSameModelSameRequest(t *testing.T) {
|
||||||
s.pendingReqCh <- scenario1a.req
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = getGpuFn
|
||||||
|
s.getCpuFn = getCpuFn
|
||||||
|
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
||||||
|
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
|
||||||
|
b.req.model = a.req.model
|
||||||
|
b.ggml = a.ggml
|
||||||
|
|
||||||
|
s.newServerFn = a.newServer
|
||||||
|
slog.Info("a")
|
||||||
|
s.pendingReqCh <- a.req
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario1a.req.successCh:
|
case resp := <-a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario1a.req.errCh)
|
require.Empty(t, a.req.errCh)
|
||||||
case err := <-scenario1a.req.errCh:
|
case err := <-a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same runner as first request due to not needing a reload
|
// Same runner as first request due to not needing a reload
|
||||||
s.newServerFn = scenario1b.newServer
|
s.newServerFn = b.newServer
|
||||||
slog.Info("scenario1b")
|
slog.Info("b")
|
||||||
s.pendingReqCh <- scenario1b.req
|
s.pendingReqCh <- b.req
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario1b.req.successCh:
|
case resp := <-b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario1b.req.errCh)
|
require.Empty(t, b.req.errCh)
|
||||||
case err := <-scenario1b.req.errCh:
|
case err := <-b.req.errCh:
|
||||||
|
t.Fatal(err.Error())
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = getGpuFn
|
||||||
|
s.getCpuFn = getCpuFn
|
||||||
|
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
||||||
|
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
|
||||||
|
tmpModel := *a.req.model
|
||||||
|
b.req.model = &tmpModel
|
||||||
|
b.ggml = a.ggml
|
||||||
|
|
||||||
|
s.newServerFn = a.newServer
|
||||||
|
slog.Info("a")
|
||||||
|
s.pendingReqCh <- a.req
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
s.Run(ctx)
|
||||||
|
select {
|
||||||
|
case resp := <-a.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, a.srv)
|
||||||
|
require.Empty(t, s.pendingReqCh)
|
||||||
|
require.Empty(t, a.req.errCh)
|
||||||
|
case err := <-a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger a reload
|
// Trigger a reload
|
||||||
s.newServerFn = scenario2a.newServer
|
s.newServerFn = b.newServer
|
||||||
scenario2a.req.model.AdapterPaths = []string{"new"}
|
b.req.model.AdapterPaths = []string{"new"}
|
||||||
slog.Info("scenario2a")
|
slog.Info("b")
|
||||||
s.pendingReqCh <- scenario2a.req
|
s.pendingReqCh <- b.req
|
||||||
// finish first two requests, so model can reload
|
// finish first two requests, so model can reload
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
scenario1a.ctxDone()
|
a.ctxDone()
|
||||||
scenario1b.ctxDone()
|
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario2a.req.successCh:
|
case resp := <-b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario2a.srv)
|
require.Equal(t, resp.llama, b.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario2a.req.errCh)
|
require.Empty(t, b.req.errCh)
|
||||||
case err := <-scenario2a.req.errCh:
|
case err := <-b.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestsMultipleLoadedModels(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = getGpuFn
|
||||||
|
s.getCpuFn = getCpuFn
|
||||||
|
|
||||||
|
// Multiple loaded models
|
||||||
|
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
|
||||||
|
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
|
||||||
|
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
|
||||||
|
c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||||
|
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
|
||||||
|
|
||||||
envconfig.MaxRunners = 1
|
envconfig.MaxRunners = 1
|
||||||
s.newServerFn = scenario3a.newServer
|
s.newServerFn = a.newServer
|
||||||
slog.Info("scenario3a")
|
slog.Info("a")
|
||||||
s.pendingReqCh <- scenario3a.req
|
s.pendingReqCh <- a.req
|
||||||
// finish prior request, so new model can load
|
s.Run(ctx)
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
scenario2a.ctxDone()
|
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3a.req.successCh:
|
case resp := <-a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3a.req.errCh)
|
require.Empty(t, a.req.errCh)
|
||||||
case err := <-scenario3a.req.errCh:
|
case err := <-a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
|
@ -262,15 +292,15 @@ func TestRequests(t *testing.T) {
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
envconfig.MaxRunners = 0
|
envconfig.MaxRunners = 0
|
||||||
s.newServerFn = scenario3b.newServer
|
s.newServerFn = b.newServer
|
||||||
slog.Info("scenario3b")
|
slog.Info("b")
|
||||||
s.pendingReqCh <- scenario3b.req
|
s.pendingReqCh <- b.req
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3b.req.successCh:
|
case resp := <-b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3b.srv)
|
require.Equal(t, resp.llama, b.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3b.req.errCh)
|
require.Empty(t, b.req.errCh)
|
||||||
case err := <-scenario3b.req.errCh:
|
case err := <-b.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
|
@ -280,15 +310,15 @@ func TestRequests(t *testing.T) {
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// This is a CPU load with NumGPU = 0 so it should load
|
// This is a CPU load with NumGPU = 0 so it should load
|
||||||
s.newServerFn = scenario3c.newServer
|
s.newServerFn = c.newServer
|
||||||
slog.Info("scenario3c")
|
slog.Info("c")
|
||||||
s.pendingReqCh <- scenario3c.req
|
s.pendingReqCh <- c.req
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3c.req.successCh:
|
case resp := <-c.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3c.srv)
|
require.Equal(t, resp.llama, c.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3c.req.errCh)
|
require.Empty(t, c.req.errCh)
|
||||||
case err := <-scenario3c.req.errCh:
|
case err := <-c.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
|
@ -298,25 +328,25 @@ func TestRequests(t *testing.T) {
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// Try to load a model that wont fit
|
// Try to load a model that wont fit
|
||||||
s.newServerFn = scenario3d.newServer
|
s.newServerFn = d.newServer
|
||||||
slog.Info("scenario3d")
|
slog.Info("d")
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 3)
|
require.Len(t, s.loaded, 3)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||||
time.Sleep(2 * time.Millisecond)
|
time.Sleep(2 * time.Millisecond)
|
||||||
s.pendingReqCh <- scenario3d.req
|
s.pendingReqCh <- d.req
|
||||||
// finish prior request, so new model can load
|
// finish prior request, so new model can load
|
||||||
time.Sleep(6 * time.Millisecond)
|
time.Sleep(6 * time.Millisecond)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 2)
|
require.Len(t, s.loaded, 2)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
scenario3b.ctxDone()
|
b.ctxDone()
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3d.req.successCh:
|
case resp := <-d.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3d.srv)
|
require.Equal(t, resp.llama, d.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3d.req.errCh)
|
require.Empty(t, d.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
@ -329,26 +359,19 @@ func TestGetRunner(t *testing.T) {
|
||||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||||
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||||
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
|
||||||
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
envconfig.MaxQueuedRequests = 1
|
envconfig.MaxQueuedRequests = 1
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
s.getGpuFn = getGpuFn
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
s.getCpuFn = getCpuFn
|
||||||
g.TotalMemory = 24 * format.GigaByte
|
s.newServerFn = a.newServer
|
||||||
g.FreeMemory = 12 * format.GigaByte
|
slog.Info("a")
|
||||||
return []gpu.GpuInfo{g}
|
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||||
}
|
|
||||||
s.newServerFn = scenario1a.newServer
|
|
||||||
slog.Info("scenario1a")
|
|
||||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
slog.Info("scenario1b")
|
slog.Info("b")
|
||||||
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
require.Empty(t, successCh1b)
|
require.Empty(t, successCh1b)
|
||||||
require.Len(t, errCh1b, 1)
|
require.Len(t, errCh1b, 1)
|
||||||
|
@ -357,22 +380,24 @@ func TestGetRunner(t *testing.T) {
|
||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
case resp := <-successCh1a:
|
case resp := <-successCh1a:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, errCh1a)
|
require.Empty(t, errCh1a)
|
||||||
|
case err := <-errCh1a:
|
||||||
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
scenario1a.ctxDone()
|
a.ctxDone() // Set "a" model to idle so it can unload
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 1)
|
require.Len(t, s.loaded, 1)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
scenario1c.req.model.ModelPath = "bad path"
|
c.req.model.ModelPath = "bad path"
|
||||||
slog.Info("scenario1c")
|
slog.Info("c")
|
||||||
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||||
// Starts in pending channel, then should be quickly processsed to return an error
|
// Starts in pending channel, then should be quickly processsed to return an error
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||||
require.Empty(t, successCh1c)
|
require.Empty(t, successCh1c)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Empty(t, s.loaded)
|
require.Empty(t, s.loaded)
|
||||||
|
@ -380,7 +405,7 @@ func TestGetRunner(t *testing.T) {
|
||||||
require.Len(t, errCh1c, 1)
|
require.Len(t, errCh1c, 1)
|
||||||
err = <-errCh1c
|
err = <-errCh1c
|
||||||
require.Contains(t, err.Error(), "bad path")
|
require.Contains(t, err.Error(), "bad path")
|
||||||
scenario1b.ctxDone()
|
b.ctxDone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
||||||
|
@ -389,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
// Same model, same request
|
// Same model, same request
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
|
@ -411,6 +436,8 @@ func TestPrematureExpired(t *testing.T) {
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
slog.Info("sending premature expired event now")
|
slog.Info("sending premature expired event now")
|
||||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||||
|
case err := <-errCh1a:
|
||||||
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
@ -446,6 +473,8 @@ func TestUseLoadedRunner(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case success := <-req.successCh:
|
case success := <-req.successCh:
|
||||||
require.Equal(t, r1, success)
|
require.Equal(t, r1, success)
|
||||||
|
case err := <-req.errCh:
|
||||||
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
@ -625,8 +654,7 @@ func TestAlreadyCanceled(t *testing.T) {
|
||||||
defer done()
|
defer done()
|
||||||
dctx, done2 := context.WithCancel(ctx)
|
dctx, done2 := context.WithCancel(ctx)
|
||||||
done2()
|
done2()
|
||||||
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
|
||||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
slog.Info("scenario1a")
|
slog.Info("scenario1a")
|
||||||
s.pendingReqCh <- scenario1a.req
|
s.pendingReqCh <- scenario1a.req
|
||||||
|
|
2
server/testdata/tools/command-r-plus.gotmpl
vendored
2
server/testdata/tools/command-r-plus.gotmpl
vendored
|
@ -46,7 +46,7 @@ Action: ```json
|
||||||
{{- range .ToolCalls }}
|
{{- range .ToolCalls }}
|
||||||
{
|
{
|
||||||
"tool_name": "{{ .Function.Name }}",
|
"tool_name": "{{ .Function.Name }}",
|
||||||
"parameters": {{ json .Function.Arguments }}
|
"parameters": {{ .Function.Arguments }}
|
||||||
}
|
}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
]```
|
]```
|
||||||
|
|
4
server/testdata/tools/firefunction.gotmpl
vendored
4
server/testdata/tools/firefunction.gotmpl
vendored
|
@ -17,7 +17,7 @@ If you decide to call functions:
|
||||||
|
|
||||||
Available functions as JSON spec:
|
Available functions as JSON spec:
|
||||||
{{- if .Tools }}
|
{{- if .Tools }}
|
||||||
{{ json .Tools }}
|
{{ .Tools }}
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- range .Messages }}<|start_header_id|>
|
{{- range .Messages }}<|start_header_id|>
|
||||||
|
@ -25,7 +25,7 @@ Available functions as JSON spec:
|
||||||
{{- end }}<|end_header_id|>
|
{{- end }}<|end_header_id|>
|
||||||
{{- if .Content }}{{ .Content }}
|
{{- if .Content }}{{ .Content }}
|
||||||
{{- else if .ToolCalls }} functools[
|
{{- else if .ToolCalls }} functools[
|
||||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
|
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||||
{{- end }}]
|
{{- end }}]
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
43
server/testdata/tools/llama3-groq-tool-use.gotmpl
vendored
Normal file
43
server/testdata/tools/llama3-groq-tool-use.gotmpl
vendored
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
{{- if .Messages }}
|
||||||
|
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .System }}
|
||||||
|
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>,"arguments": <args-dict>}
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
Here are the available tools:
|
||||||
|
<tools>
|
||||||
|
{{- range .Tools }} {{ .Function }}
|
||||||
|
{{- end }} </tools>
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- range .Messages }}
|
||||||
|
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||||
|
|
||||||
|
{{ if eq .Role "user" }}{{ .Content }}
|
||||||
|
{{- else if eq .Role "assistant" }}
|
||||||
|
{{- if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}<tool_call>
|
||||||
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{- end }}
|
||||||
|
</tool_call>
|
||||||
|
{{- end }}
|
||||||
|
{{- else if eq .Role "tool" }}<tool_response>
|
||||||
|
{{ .Content }}
|
||||||
|
</tool_response>
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ else }}
|
||||||
|
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ end }}{{ .Response }}
|
||||||
|
{{- if .Response }}<|eot_id|>
|
||||||
|
{{- end }}
|
24
server/testdata/tools/llama3-groq-tool-use.out
vendored
Normal file
24
server/testdata/tools/llama3-groq-tool-use.out
vendored
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>,"arguments": <args-dict>}
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
Here are the available tools:
|
||||||
|
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||||
|
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
|
||||||
|
|
||||||
|
<tool_response>
|
||||||
|
22
|
||||||
|
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
4
server/testdata/tools/mistral.gotmpl
vendored
4
server/testdata/tools/mistral.gotmpl
vendored
|
@ -1,13 +1,13 @@
|
||||||
{{- range $index, $_ := .Messages }}
|
{{- range $index, $_ := .Messages }}
|
||||||
{{- if eq .Role "user" }}
|
{{- if eq .Role "user" }}
|
||||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
|
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||||
|
|
||||||
{{ end }}{{ .Content }}[/INST]
|
{{ end }}{{ .Content }}[/INST]
|
||||||
{{- else if eq .Role "assistant" }}
|
{{- else if eq .Role "assistant" }}
|
||||||
{{- if .Content }} {{ .Content }}</s>
|
{{- if .Content }} {{ .Content }}</s>
|
||||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
|
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
{{- end }}]</s>
|
{{- end }}]</s>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||||
|
|
|
@ -150,7 +150,9 @@ func (t *Template) Vars() []string {
|
||||||
|
|
||||||
type Values struct {
|
type Values struct {
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
Tools []api.Tool
|
api.Tools
|
||||||
|
Prompt string
|
||||||
|
Suffix string
|
||||||
|
|
||||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||||
forceLegacy bool
|
forceLegacy bool
|
||||||
|
@ -204,11 +206,18 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system, messages := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
if v.Prompt != "" && v.Suffix != "" {
|
||||||
|
return t.Template.Execute(w, map[string]any{
|
||||||
|
"Prompt": v.Prompt,
|
||||||
|
"Suffix": v.Suffix,
|
||||||
|
"Response": "",
|
||||||
|
})
|
||||||
|
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": messages,
|
||||||
"Tools": v.Tools,
|
"Tools": v.Tools,
|
||||||
|
"Response": "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -255,6 +264,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||||
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||||
cut = true
|
cut = true
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return cut
|
return cut
|
||||||
|
@ -264,6 +274,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
|
"Response": response,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {
|
||||||
|
|
||||||
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"mistral assistant",
|
||||||
|
[]template{
|
||||||
|
{"no response", `[INST] {{ .Prompt }}[/INST] `},
|
||||||
|
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
|
{"messages", `
|
||||||
|
{{- range $i, $m := .Messages }}
|
||||||
|
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
|
||||||
|
{{- end }}`},
|
||||||
|
},
|
||||||
|
Values{
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello friend!"},
|
||||||
|
{Role: "assistant", Content: "Hello human!"},
|
||||||
|
{Role: "user", Content: "What is your name?"},
|
||||||
|
{Role: "assistant", Content: "My name is Ollama and I"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"chatml",
|
"chatml",
|
||||||
[]template{
|
[]template{
|
||||||
|
@ -359,3 +379,38 @@ Answer: `,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithSuffix(t *testing.T) {
|
||||||
|
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
values Values
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, tt.values); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue