diff --git a/openai/openai.go b/openai/openai.go index 10e5b09e..6b469da7 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -200,9 +200,9 @@ func toolCallId() string { return "call_" + strings.ToLower(string(b)) } -func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { - toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) - for i, tc := range r.Message.ToolCalls { +func toToolCalls(tc []api.ToolCall) []ToolCall { + toolCalls := make([]ToolCall, len(tc)) + for i, tc := range tc { toolCalls[i].ID = toolCallId() toolCalls[i].Type = "function" toolCalls[i].Function.Name = tc.Function.Name @@ -215,7 +215,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { toolCalls[i].Function.Arguments = string(args) } + return toolCalls +} +func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { + toolCalls := toToolCalls(r.Message.ToolCalls) return ChatCompletion{ Id: id, Object: "chat.completion", @@ -244,6 +248,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { + toolCalls := toToolCalls(r.Message.ToolCalls) return ChatCompletionChunk{ Id: id, Object: "chat.completion.chunk", @@ -252,7 +257,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { SystemFingerprint: "fp_ollama", Choices: []ChunkChoice{{ Index: 0, - Delta: Message{Role: "assistant", Content: r.Message.Content}, + Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, FinishReason: func(reason string) *string { if len(reason) > 0 { return &reason diff --git a/server/model_test.go b/server/model_test.go index 304d4655..47c4728e 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) { {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, diff --git a/server/routes.go b/server/routes.go index c13cd023..d9e4fb66 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) { prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) if err != nil { + slog.Error("chat prompt error", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) + var sb strings.Builder + var hasToolCalls bool if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1492,7 +1495,34 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - ch <- res + // TODO: tool call checking and filtering should be moved outside of this callback once streaming + // however this was a simple change for now without reworking streaming logic of this (and other) + // handlers + if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { + ch <- res + return + } + + // Streaming tool calls: + // If tools are recognized, use a flag to track the sending of a tool downstream + // This ensures that content is cleared from the message on the last chunk sent + sb.WriteString(r.Content) + if toolCalls, ok := m.parseToolCalls(sb.String()); ok { + res.Message.ToolCalls = toolCalls + res.Message.Content = "" + sb.Reset() + hasToolCalls = true + ch <- res + return + } + + if r.Done { + // Send any remaining content if no tool calls were detected + if !hasToolCalls { + res.Message.Content = sb.String() + } + ch <- res + } }); err != nil { ch <- gin.H{"error": err.Error()} } diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 53501cc6..4bde55bb 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "strings" + "sync" "testing" "time" @@ -25,10 +26,14 @@ type mockRunner struct { // CompletionRequest is only valid until the next call to Completion llm.CompletionRequest llm.CompletionResponse + CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error } -func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { +func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { m.CompletionRequest = r + if m.CompletionFn != nil { + return m.CompletionFn(ctx, r, fn) + } fn(m.CompletionResponse) return nil } @@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) { Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ -{{- if .System }}System: {{ .System }} {{ end }} -{{- if .Prompt }}User: {{ .Prompt }} {{ end }} -{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" +{{- if .Tools }} +{{ .Tools }} +{{ end }} +{{- range .Messages }} +{{- .Role }}: {{ .Content }} +{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{- end }} +{{ end }}""" `, createBinFile(t, llm.KV{ "general.architecture": "llama", "llama.block_count": uint32(1), @@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) { t.Errorf("expected status 200, got %d", w.Code) } - if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } @@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) { t.Errorf("expected status 200, got %d", w.Code) } - if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } @@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) { t.Errorf("expected status 200, got %d", w.Code) } - if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } @@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) { t.Errorf("expected status 200, got %d", w.Code) } - if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" { + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") }) + + t.Run("messages with tools (non-streaming)", func(t *testing.T) { + if w.Code != http.StatusOK { + t.Fatalf("failed to create test-system model: %d", w.Code) + } + + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + }{ + "location": { + Type: "string", + Description: "The city and state", + }, + "unit": { + Type: "string", + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + } + + mock.CompletionResponse = llm.CompletionResponse{ + Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`, + Done: true, + DoneReason: "done", + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + } + + streamRequest := true + + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "What's the weather in Seattle?"}, + }, + Tools: tools, + Stream: &streamRequest, + }) + + if w.Code != http.StatusOK { + var errResp struct { + Error string `json:"error"` + } + if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil { + t.Logf("Failed to decode error response: %v", err) + } else { + t.Logf("Error response: %s", errResp.Error) + } + } + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp api.ChatResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if resp.Message.ToolCalls == nil { + t.Error("expected tool calls, got nil") + } + + expectedToolCall := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Seattle, WA", + "unit": "celsius", + }, + }, + } + + if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" { + t.Errorf("tool call mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("messages with tools (streaming)", func(t *testing.T) { + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + }{ + "location": { + Type: "string", + Description: "The city and state", + }, + "unit": { + Type: "string", + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + } + + // Simulate streaming response with multiple chunks + var wg sync.WaitGroup + wg.Add(1) + + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + defer wg.Done() + + // Send chunks with small delays to simulate streaming + responses := []llm.CompletionResponse{ + { + Content: `{"name":"get_`, + Done: false, + PromptEvalCount: 1, + PromptEvalDuration: 1, + }, + { + Content: `weather","arguments":{"location":"Seattle`, + Done: false, + PromptEvalCount: 2, + PromptEvalDuration: 1, + }, + { + Content: `, WA","unit":"celsius"}}`, + Done: true, + DoneReason: "tool_call", + PromptEvalCount: 3, + PromptEvalDuration: 1, + }, + } + + for _, resp := range responses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + fn(resp) + time.Sleep(10 * time.Millisecond) // Small delay between chunks + } + } + return nil + } + + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "What's the weather in Seattle?"}, + }, + Tools: tools, + Stream: &stream, + }) + + wg.Wait() + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + // Read and validate the streamed responses + decoder := json.NewDecoder(w.Body) + var finalToolCall api.ToolCall + + for { + var resp api.ChatResponse + if err := decoder.Decode(&resp); err == io.EOF { + break + } else if err != nil { + t.Fatal(err) + } + + if resp.Done { + if len(resp.Message.ToolCalls) != 1 { + t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls)) + } + finalToolCall = resp.Message.ToolCalls[0] + } + } + + expectedToolCall := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Seattle, WA", + "unit": "celsius", + }, + }, + } + + if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" { + t.Errorf("final tool call mismatch (-got +want):\n%s", diff) + } + }) } func TestGenerate(t *testing.T) {