From 5f8051180e3b9aeafc153f6b5056e7358a939c88 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Fri, 29 Nov 2024 20:00:09 -0800 Subject: [PATCH] Enable index tracking for tools - openai api support (#7888) --- api/types.go | 1 + openai/openai.go | 2 ++ openai/openai_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++- server/routes.go | 9 +++-- 4 files changed, 89 insertions(+), 4 deletions(-) diff --git a/api/types.go b/api/types.go index e5291a02..d2108f88 100644 --- a/api/types.go +++ b/api/types.go @@ -146,6 +146,7 @@ type ToolCall struct { } type ToolCallFunction struct { + Index int `json:"index,omitempty"` Name string `json:"name"` Arguments ToolCallFunctionArguments `json:"arguments"` } diff --git a/openai/openai.go b/openai/openai.go index 6b469da7..bf1879f9 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -140,6 +140,7 @@ type CompletionChunk struct { type ToolCall struct { ID string `json:"id"` + Index int `json:"index"` Type string `json:"type"` Function struct { Name string `json:"name"` @@ -206,6 +207,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall { toolCalls[i].ID = toolCallId() toolCalls[i].Type = "function" toolCalls[i].Function.Name = tc.Function.Name + toolCalls[i].Index = tc.Function.Index args, err := json.Marshal(tc.Function.Arguments) if err != nil { diff --git a/openai/openai_test.go b/openai/openai_test.go index eabf5b66..e17037de 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) { Stream: &False, }, }, - + { + name: "chat handler with streaming tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris?"} + ], + "stream": true, + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": { + "type": "string", + "description": "The city and state" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + } + } + } + }] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris?", + }, + }, + Tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + }{ + "location": { + Type: "string", + Description: "The city and state", + }, + "unit": { + Type: "string", + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &True, + }, + }, { name: "chat handler error forwarding", body: `{ diff --git a/server/routes.go b/server/routes.go index d9e4fb66..edf1f4c5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1469,7 +1469,7 @@ func (s *Server) ChatHandler(c *gin.Context) { go func() { defer close(ch) var sb strings.Builder - var hasToolCalls bool + var toolCallIndex int = 0 if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1509,16 +1509,19 @@ func (s *Server) ChatHandler(c *gin.Context) { sb.WriteString(r.Content) if toolCalls, ok := m.parseToolCalls(sb.String()); ok { res.Message.ToolCalls = toolCalls + for i := range toolCalls { + toolCalls[i].Function.Index = toolCallIndex + toolCallIndex++ + } res.Message.Content = "" sb.Reset() - hasToolCalls = true ch <- res return } if r.Done { // Send any remaining content if no tool calls were detected - if !hasToolCalls { + if toolCallIndex == 0 { res.Message.Content = sb.String() } ch <- res