Enable index tracking for tools - openai api support (#7888)
This commit is contained in:
parent
39e29ae5dd
commit
5f8051180e
4 changed files with 89 additions and 4 deletions
|
@ -146,6 +146,7 @@ type ToolCall struct {
|
|||
}
|
||||
|
||||
type ToolCallFunction struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
|
|
@ -140,6 +140,7 @@ type CompletionChunk struct {
|
|||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
|
@ -206,6 +207,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
|||
toolCalls[i].ID = toolCallId()
|
||||
toolCalls[i].Type = "function"
|
||||
toolCalls[i].Function.Name = tc.Function.Name
|
||||
toolCalls[i].Index = tc.Function.Index
|
||||
|
||||
args, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
|
|
|
@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) {
|
|||
Stream: &False,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "chat handler with streaming tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||
],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris?",
|
||||
},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
|
|
|
@ -1469,7 +1469,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
go func() {
|
||||
defer close(ch)
|
||||
var sb strings.Builder
|
||||
var hasToolCalls bool
|
||||
var toolCallIndex int = 0
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
|
@ -1509,16 +1509,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
sb.WriteString(r.Content)
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
toolCalls[i].Function.Index = toolCallIndex
|
||||
toolCallIndex++
|
||||
}
|
||||
res.Message.Content = ""
|
||||
sb.Reset()
|
||||
hasToolCalls = true
|
||||
ch <- res
|
||||
return
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
// Send any remaining content if no tool calls were detected
|
||||
if !hasToolCalls {
|
||||
if toolCallIndex == 0 {
|
||||
res.Message.Content = sb.String()
|
||||
}
|
||||
ch <- res
|
||||
|
|
Loading…
Reference in a new issue