commit 9fa4a19138
10 changed files with 414 additions and 38 deletions
@@ -346,6 +346,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
 - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
+- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
+- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
+- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
 - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
 - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
 - [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
@@ -146,6 +146,7 @@ type ToolCall struct {
 }

 type ToolCallFunction struct {
+    Index     int                       `json:"index,omitempty"`
     Name      string                    `json:"name"`
     Arguments ToolCallFunctionArguments `json:"arguments"`
 }
@@ -105,7 +105,7 @@ make apply-patches

 **Pin to new base commit**

-To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
+To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`

 #### Applying patches

@@ -833,10 +833,21 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
     }
 }

+type multiLPath []string
+
+func (m *multiLPath) Set(value string) error {
+    *m = append(*m, value)
+    return nil
+}
+
+func (m *multiLPath) String() string {
+    return strings.Join(*m, ", ")
+}
+
 func (s *Server) loadModel(
     params llama.ModelParams,
     mpath string,
-    lpath string,
+    lpath multiLPath,
     ppath string,
     kvSize int,
     flashAttention bool,
@@ -857,12 +868,14 @@ func (s *Server) loadModel(
         panic(err)
     }

-    if lpath != "" {
-        err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
+    if lpath.String() != "" {
+        for _, path := range lpath {
+            err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
             if err != nil {
                 panic(err)
             }
         }
+    }

     if ppath != "" {
         var err error
@@ -890,7 +903,6 @@ func main() {
     mainGpu := flag.Int("main-gpu", 0, "Main GPU")
     flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
     kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
-    lpath := flag.String("lora", "", "Path to lora layer file")
     port := flag.Int("port", 8080, "Port to expose the server on")
     threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
     verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -900,6 +912,9 @@ func main() {
     multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
     requirements := flag.Bool("requirements", false, "print json requirement information")

+    var lpaths multiLPath
+    flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
+
     flag.Parse()
     if *requirements {
         printRequirements(os.Stdout)
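The `-lora` change above relies on Go's `flag.Value` interface: any type with `Set` and `String` methods can back a flag, and `Set` runs once per occurrence, so repeated flags accumulate. A minimal, self-contained sketch of the same pattern (the `repeatable` name is illustrative, not from the commit):

```go
package main

import (
	"flag"
	"fmt"
	"strings"
)

// repeatable collects every value given for a flag, mirroring multiLPath.
type repeatable []string

// Set is called by the flag package once per occurrence of the flag.
func (r *repeatable) Set(value string) error {
	*r = append(*r, value)
	return nil
}

// String is used when printing defaults and, in the commit, to test emptiness.
func (r *repeatable) String() string {
	return strings.Join(*r, ", ")
}

func main() {
	var loras repeatable
	flag.Var(&loras, "lora", "Path to lora layer file (can be specified multiple times)")
	flag.Parse()

	// e.g. go run . -lora a.gguf -lora b.gguf  ->  [a.gguf b.gguf]
	fmt.Println(loras)
}
```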
@@ -946,7 +961,7 @@ func main() {
     params := llama.ModelParams{
         NumGpuLayers: *nGpuLayers,
         MainGpu:      *mainGpu,
-        UseMmap:      !*noMmap && *lpath == "",
+        UseMmap:      !*noMmap && lpaths.String() == "",
         UseMlock:     *mlock,
         TensorSplit:  tensorSplitFloats,
         Progress: func(progress float32) {
@@ -955,7 +970,7 @@ func main() {
     }

     server.ready.Add(1)
-    go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
+    go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)

     server.cond = sync.NewCond(&server.mu)

@@ -144,10 +144,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
     // Loop through potential servers
     finalErr := errors.New("no suitable llama servers found")

-    if len(adapters) > 1 {
-        return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
-    }
-
     rDir, err := runners.Refresh(build.EmbedFS)
     if err != nil {
         return nil, err
@@ -201,8 +197,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
     }

     if len(adapters) > 0 {
-        // TODO: applying multiple adapters is not supported by the llama.cpp server yet
-        params = append(params, "--lora", adapters[0])
+        for _, adapter := range adapters {
+            params = append(params, "--lora", adapter)
+        }
     }

     if len(projectors) > 0 {
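On the scheduler side, the guard that rejected more than one adapter is gone, and each adapter path is now forwarded to the runner as its own `--lora` argument. A rough sketch of the resulting argument list, with hypothetical paths:

```go
package main

import "fmt"

func main() {
	// Hypothetical adapter paths; each one becomes its own --lora flag,
	// matching the loop added in NewLlamaServer.
	adapters := []string{"/models/adapter-a.gguf", "/models/adapter-b.gguf"}

	params := []string{"--model", "/models/base.gguf"}
	for _, adapter := range adapters {
		params = append(params, "--lora", adapter)
	}

	fmt.Println(params)
	// [--model /models/base.gguf --lora /models/adapter-a.gguf --lora /models/adapter-b.gguf]
}
```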
@@ -140,6 +140,7 @@ type CompletionChunk struct {

 type ToolCall struct {
     ID       string `json:"id"`
+    Index    int    `json:"index"`
     Type     string `json:"type"`
     Function struct {
         Name      string `json:"name"`
@@ -200,12 +201,13 @@ func toolCallId() string {
     return "call_" + strings.ToLower(string(b))
 }

-func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
-    toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
-    for i, tc := range r.Message.ToolCalls {
+func toToolCalls(tc []api.ToolCall) []ToolCall {
+    toolCalls := make([]ToolCall, len(tc))
+    for i, tc := range tc {
         toolCalls[i].ID = toolCallId()
         toolCalls[i].Type = "function"
         toolCalls[i].Function.Name = tc.Function.Name
+        toolCalls[i].Index = tc.Function.Index

         args, err := json.Marshal(tc.Function.Arguments)
         if err != nil {
@@ -215,7 +217,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {

         toolCalls[i].Function.Arguments = string(args)
     }
-
+    return toolCalls
+}
+
+func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+    toolCalls := toToolCalls(r.Message.ToolCalls)
     return ChatCompletion{
         Id:     id,
         Object: "chat.completion",
@@ -244,6 +250,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 }

 func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+    toolCalls := toToolCalls(r.Message.ToolCalls)
     return ChatCompletionChunk{
         Id:     id,
         Object: "chat.completion.chunk",
@@ -252,7 +259,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
         SystemFingerprint: "fp_ollama",
         Choices: []ChunkChoice{{
             Index: 0,
-            Delta: Message{Role: "assistant", Content: r.Message.Content},
+            Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
             FinishReason: func(reason string) *string {
                 if len(reason) > 0 {
                     return &reason
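With `Index` added to the OpenAI-compatible `ToolCall` and the chunk `Delta` now carrying `ToolCalls`, streamed chunks can include tool-call fragments that clients correlate by index. A trimmed sketch of the wire shape (struct copies reduced from the diff; the call ID value is made up):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the openai package's chunk shapes; only the fields
// relevant to streamed tool calls are kept.
type toolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type delta struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []toolCall `json:"tool_calls,omitempty"`
}

func main() {
	// Hypothetical call ID; the real one comes from toolCallId().
	tc := toolCall{ID: "call_abc123", Index: 0, Type: "function"}
	tc.Function.Name = "get_weather"
	tc.Function.Arguments = `{"location":"Seattle, WA"}`

	b, _ := json.MarshalIndent(delta{Role: "assistant", ToolCalls: []toolCall{tc}}, "", "  ")
	fmt.Println(string(b))
}
```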
@@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) {
                 Stream: &False,
             },
         },
+        {
+            name: "chat handler with streaming tools",
+            body: `{
+                "model": "test-model",
+                "messages": [
+                    {"role": "user", "content": "What's the weather like in Paris?"}
+                ],
+                "stream": true,
+                "tools": [{
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "description": "Get the current weather",
+                        "parameters": {
+                            "type": "object",
+                            "required": ["location"],
+                            "properties": {
+                                "location": {
+                                    "type": "string",
+                                    "description": "The city and state"
+                                },
+                                "unit": {
+                                    "type": "string",
+                                    "enum": ["celsius", "fahrenheit"]
+                                }
+                            }
+                        }
+                    }
+                }]
+            }`,
+            req: api.ChatRequest{
+                Model: "test-model",
+                Messages: []api.Message{
+                    {
+                        Role:    "user",
+                        Content: "What's the weather like in Paris?",
+                    },
+                },
+                Tools: []api.Tool{
+                    {
+                        Type: "function",
+                        Function: api.ToolFunction{
+                            Name:        "get_weather",
+                            Description: "Get the current weather",
+                            Parameters: struct {
+                                Type       string   `json:"type"`
+                                Required   []string `json:"required"`
+                                Properties map[string]struct {
+                                    Type        string   `json:"type"`
+                                    Description string   `json:"description"`
+                                    Enum        []string `json:"enum,omitempty"`
+                                } `json:"properties"`
+                            }{
+                                Type:     "object",
+                                Required: []string{"location"},
+                                Properties: map[string]struct {
+                                    Type        string   `json:"type"`
+                                    Description string   `json:"description"`
+                                    Enum        []string `json:"enum,omitempty"`
+                                }{
+                                    "location": {
+                                        Type:        "string",
+                                        Description: "The city and state",
+                                    },
+                                    "unit": {
+                                        Type: "string",
+                                        Enum: []string{"celsius", "fahrenheit"},
+                                    },
+                                },
+                            },
+                        },
+                    },
+                },
+                Options: map[string]any{
+                    "temperature": 1.0,
+                    "top_p":       1.0,
+                },
+                Stream: &True,
+            },
+        },
         {
             name: "chat handler error forwarding",
             body: `{
@@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
         {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]

 The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
         {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:

 [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
@@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) {

     prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
     if err != nil {
+        slog.Error("chat prompt error", "error", err)
         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
         return
     }
@@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
     ch := make(chan any)
     go func() {
         defer close(ch)
+        var sb strings.Builder
+        var toolCallIndex int = 0
         if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
             Prompt:  prompt,
             Images:  images,
@@ -1492,7 +1495,37 @@ func (s *Server) ChatHandler(c *gin.Context) {
                 res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
             }

+            // TODO: tool call checking and filtering should be moved outside of this callback once streaming
+            // however this was a simple change for now without reworking streaming logic of this (and other)
+            // handlers
+            if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
                 ch <- res
+                return
+            }
+
+            // Streaming tool calls:
+            // If tools are recognized, use a flag to track the sending of a tool downstream
+            // This ensures that content is cleared from the message on the last chunk sent
+            sb.WriteString(r.Content)
+            if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+                res.Message.ToolCalls = toolCalls
+                for i := range toolCalls {
+                    toolCalls[i].Function.Index = toolCallIndex
+                    toolCallIndex++
+                }
+                res.Message.Content = ""
+                sb.Reset()
+                ch <- res
+                return
+            }
+
+            if r.Done {
+                // Send any remaining content if no tool calls were detected
+                if toolCallIndex == 0 {
+                    res.Message.Content = sb.String()
+                }
+                ch <- res
+            }
         }); err != nil {
             ch <- gin.H{"error": err.Error()}
         }
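The handler buffers streamed content in a `strings.Builder` and, once the accumulated text parses as tool calls, sends them in a single chunk with the content cleared. A standalone sketch of that accumulate-then-parse loop, using a stand-in parser (the real `parseToolCalls` is a template-driven method on the model):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

type toolCall struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments"`
}

// parseToolCalls stands in for the model's template-aware parser: it only
// succeeds once the buffered text has become complete, valid JSON.
func parseToolCalls(s string) ([]toolCall, bool) {
	var tc toolCall
	if err := json.Unmarshal([]byte(s), &tc); err != nil {
		return nil, false
	}
	return []toolCall{tc}, true
}

func main() {
	// Chunks arrive as in the streaming test below: a tool call split mid-JSON.
	chunks := []string{`{"name":"get_`, `weather","arguments":{"location":"Seattle`, `, WA"}}`}

	var sb strings.Builder
	for _, chunk := range chunks {
		sb.WriteString(chunk)
		if calls, ok := parseToolCalls(sb.String()); ok {
			fmt.Printf("tool calls: %+v\n", calls)
			sb.Reset() // buffered text is consumed; content is not re-sent
		}
	}
}
```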
@@ -8,6 +8,7 @@ import (
     "io"
     "net/http"
     "strings"
+    "sync"
     "testing"
     "time"

@@ -25,10 +26,14 @@ type mockRunner struct {
     // CompletionRequest is only valid until the next call to Completion
     llm.CompletionRequest
     llm.CompletionResponse
+    CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
 }

-func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
     m.CompletionRequest = r
+    if m.CompletionFn != nil {
+        return m.CompletionFn(ctx, r, fn)
+    }
     fn(m.CompletionResponse)
     return nil
 }
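`CompletionFn` turns `mockRunner` into a programmable test double: when set, it replaces the canned single response, which is what lets the streaming test below emit several chunks. The pattern in isolation (generic names, not the llm package's types):

```go
package main

import "fmt"

type response struct{ Content string }

// mock holds both a canned response and an optional override hook,
// mirroring mockRunner.CompletionFn.
type mock struct {
	canned response
	fn     func(emit func(response)) error
}

func (m *mock) completion(emit func(response)) error {
	if m.fn != nil {
		return m.fn(emit) // test-provided behavior, e.g. streaming chunks
	}
	emit(m.canned) // default: one canned response
	return nil
}

func main() {
	m := &mock{canned: response{Content: "single"}}
	m.fn = func(emit func(response)) error {
		for _, c := range []string{"chunk 1", "chunk 2"} {
			emit(response{Content: c})
		}
		return nil
	}
	_ = m.completion(func(r response) { fmt.Println(r.Content) })
}
```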
@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
         Model: "test",
         Modelfile: fmt.Sprintf(`FROM %s
 TEMPLATE """
-{{- if .System }}System: {{ .System }} {{ end }}
-{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
-{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+{{- if .Tools }}
+{{ .Tools }}
+{{ end }}
+{{- range .Messages }}
+{{- .Role }}: {{ .Content }}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{- end }}
+{{ end }}"""
 `, createBinFile(t, llm.KV{
             "general.architecture": "llama",
             "llama.block_count":    uint32(1),
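The rewritten test template iterates over `.Messages` and their `.ToolCalls` instead of a fixed System/User/Response triple, which is why the expected prompts in the assertions below change to the `role: content` form. A quick way to see what it renders, using `text/template` with stand-in types (the `Arguments` field is simplified to a string here):

```go
package main

import (
	"os"
	"text/template"
)

type toolCall struct{ Function struct{ Name, Arguments string } }

type message struct {
	Role, Content string
	ToolCalls     []toolCall
}

func main() {
	// Same shape as the Modelfile TEMPLATE in the test (tools block omitted).
	tmpl := template.Must(template.New("chat").Parse(
		`{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}`))

	msgs := []message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello!"},
	}
	_ = tmpl.Execute(os.Stdout, map[string]any{"Messages": msgs})
	// Prints:
	// system: You are a helpful assistant.
	// user: Hello!
}
```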
@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
             t.Errorf("expected status 200, got %d", w.Code)
         }

-        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
             t.Errorf("mismatch (-got +want):\n%s", diff)
         }

@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
             t.Errorf("expected status 200, got %d", w.Code)
         }

-        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
             t.Errorf("mismatch (-got +want):\n%s", diff)
         }

@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
             t.Errorf("expected status 200, got %d", w.Code)
         }

-        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
             t.Errorf("mismatch (-got +want):\n%s", diff)
         }

@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
             t.Errorf("expected status 200, got %d", w.Code)
         }

-        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
+        if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
             t.Errorf("mismatch (-got +want):\n%s", diff)
         }

         checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
     })

+    t.Run("messages with tools (non-streaming)", func(t *testing.T) {
+        if w.Code != http.StatusOK {
+            t.Fatalf("failed to create test-system model: %d", w.Code)
+        }
+
+        tools := []api.Tool{
+            {
+                Type: "function",
+                Function: api.ToolFunction{
+                    Name:        "get_weather",
+                    Description: "Get the current weather",
+                    Parameters: struct {
+                        Type       string   `json:"type"`
+                        Required   []string `json:"required"`
+                        Properties map[string]struct {
+                            Type        string   `json:"type"`
+                            Description string   `json:"description"`
+                            Enum        []string `json:"enum,omitempty"`
+                        } `json:"properties"`
+                    }{
+                        Type:     "object",
+                        Required: []string{"location"},
+                        Properties: map[string]struct {
+                            Type        string   `json:"type"`
+                            Description string   `json:"description"`
+                            Enum        []string `json:"enum,omitempty"`
+                        }{
+                            "location": {
+                                Type:        "string",
+                                Description: "The city and state",
+                            },
+                            "unit": {
+                                Type: "string",
+                                Enum: []string{"celsius", "fahrenheit"},
+                            },
+                        },
+                    },
+                },
+            },
+        }
+
+        mock.CompletionResponse = llm.CompletionResponse{
+            Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
+            Done:               true,
+            DoneReason:         "done",
+            PromptEvalCount:    1,
+            PromptEvalDuration: 1,
+            EvalCount:          1,
+            EvalDuration:       1,
+        }
+
+        streamRequest := true
+
+        w := createRequest(t, s.ChatHandler, api.ChatRequest{
+            Model: "test-system",
+            Messages: []api.Message{
+                {Role: "user", Content: "What's the weather in Seattle?"},
+            },
+            Tools:  tools,
+            Stream: &streamRequest,
+        })
+
+        if w.Code != http.StatusOK {
+            var errResp struct {
+                Error string `json:"error"`
+            }
+            if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
+                t.Logf("Failed to decode error response: %v", err)
+            } else {
+                t.Logf("Error response: %s", errResp.Error)
+            }
+        }
+
+        if w.Code != http.StatusOK {
+            t.Errorf("expected status 200, got %d", w.Code)
+        }
+
+        var resp api.ChatResponse
+        if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+            t.Fatal(err)
+        }
+
+        if resp.Message.ToolCalls == nil {
+            t.Error("expected tool calls, got nil")
+        }
+
+        expectedToolCall := api.ToolCall{
+            Function: api.ToolCallFunction{
+                Name: "get_weather",
+                Arguments: api.ToolCallFunctionArguments{
+                    "location": "Seattle, WA",
+                    "unit":     "celsius",
+                },
+            },
+        }
+
+        if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
+            t.Errorf("tool call mismatch (-got +want):\n%s", diff)
+        }
+    })
+
+    t.Run("messages with tools (streaming)", func(t *testing.T) {
+        tools := []api.Tool{
+            {
+                Type: "function",
+                Function: api.ToolFunction{
+                    Name:        "get_weather",
+                    Description: "Get the current weather",
+                    Parameters: struct {
+                        Type       string   `json:"type"`
+                        Required   []string `json:"required"`
+                        Properties map[string]struct {
+                            Type        string   `json:"type"`
+                            Description string   `json:"description"`
+                            Enum        []string `json:"enum,omitempty"`
+                        } `json:"properties"`
+                    }{
+                        Type:     "object",
+                        Required: []string{"location"},
+                        Properties: map[string]struct {
+                            Type        string   `json:"type"`
+                            Description string   `json:"description"`
+                            Enum        []string `json:"enum,omitempty"`
+                        }{
+                            "location": {
+                                Type:        "string",
+                                Description: "The city and state",
+                            },
+                            "unit": {
+                                Type: "string",
+                                Enum: []string{"celsius", "fahrenheit"},
+                            },
+                        },
+                    },
+                },
+            },
+        }
+
+        // Simulate streaming response with multiple chunks
+        var wg sync.WaitGroup
+        wg.Add(1)
+
+        mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+            defer wg.Done()
+
+            // Send chunks with small delays to simulate streaming
+            responses := []llm.CompletionResponse{
+                {
+                    Content:            `{"name":"get_`,
+                    Done:               false,
+                    PromptEvalCount:    1,
+                    PromptEvalDuration: 1,
+                },
+                {
+                    Content:            `weather","arguments":{"location":"Seattle`,
+                    Done:               false,
+                    PromptEvalCount:    2,
+                    PromptEvalDuration: 1,
+                },
+                {
+                    Content:            `, WA","unit":"celsius"}}`,
+                    Done:               true,
+                    DoneReason:         "tool_call",
+                    PromptEvalCount:    3,
+                    PromptEvalDuration: 1,
+                },
+            }
+
+            for _, resp := range responses {
+                select {
+                case <-ctx.Done():
+                    return ctx.Err()
+                default:
+                    fn(resp)
+                    time.Sleep(10 * time.Millisecond) // Small delay between chunks
+                }
+            }
+            return nil
+        }
+
+        w := createRequest(t, s.ChatHandler, api.ChatRequest{
+            Model: "test-system",
+            Messages: []api.Message{
+                {Role: "user", Content: "What's the weather in Seattle?"},
+            },
+            Tools:  tools,
+            Stream: &stream,
+        })
+
+        wg.Wait()
+
+        if w.Code != http.StatusOK {
+            t.Errorf("expected status 200, got %d", w.Code)
+        }
+
+        // Read and validate the streamed responses
+        decoder := json.NewDecoder(w.Body)
+        var finalToolCall api.ToolCall
+
+        for {
+            var resp api.ChatResponse
+            if err := decoder.Decode(&resp); err == io.EOF {
+                break
+            } else if err != nil {
+                t.Fatal(err)
+            }
+
+            if resp.Done {
+                if len(resp.Message.ToolCalls) != 1 {
+                    t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
+                }
+                finalToolCall = resp.Message.ToolCalls[0]
+            }
+        }
+
+        expectedToolCall := api.ToolCall{
+            Function: api.ToolCallFunction{
+                Name: "get_weather",
+                Arguments: api.ToolCallFunctionArguments{
+                    "location": "Seattle, WA",
+                    "unit":     "celsius",
+                },
+            },
+        }
+
+        if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
+            t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
+        }
+    })
 }

 func TestGenerate(t *testing.T) {