Enable index tracking for tools - openai api support (#7888)
This commit is contained in:
parent
39e29ae5dd
commit
5f8051180e
4 changed files with 89 additions and 4 deletions
|
@ -146,6 +146,7 @@ type ToolCall struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCallFunction struct {
|
type ToolCallFunction struct {
|
||||||
|
Index int `json:"index,omitempty"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,6 +140,7 @@ type CompletionChunk struct {
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
Index int `json:"index"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Function struct {
|
Function struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
@ -206,6 +207,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||||
toolCalls[i].ID = toolCallId()
|
toolCalls[i].ID = toolCallId()
|
||||||
toolCalls[i].Type = "function"
|
toolCalls[i].Type = "function"
|
||||||
toolCalls[i].Function.Name = tc.Function.Name
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
toolCalls[i].Index = tc.Function.Index
|
||||||
|
|
||||||
args, err := json.Marshal(tc.Function.Arguments)
|
args, err := json.Marshal(tc.Function.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
Stream: &False,
|
Stream: &False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with streaming tools",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"tools": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["location"],
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
}{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}{
|
||||||
|
"location": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "The city and state",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
Type: "string",
|
||||||
|
Enum: []string{"celsius", "fahrenheit"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "chat handler error forwarding",
|
name: "chat handler error forwarding",
|
||||||
body: `{
|
body: `{
|
||||||
|
|
|
@ -1469,7 +1469,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var hasToolCalls bool
|
var toolCallIndex int = 0
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
|
@ -1509,16 +1509,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
sb.WriteString(r.Content)
|
sb.WriteString(r.Content)
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
|
for i := range toolCalls {
|
||||||
|
toolCalls[i].Function.Index = toolCallIndex
|
||||||
|
toolCallIndex++
|
||||||
|
}
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
hasToolCalls = true
|
|
||||||
ch <- res
|
ch <- res
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
// Send any remaining content if no tool calls were detected
|
// Send any remaining content if no tool calls were detected
|
||||||
if !hasToolCalls {
|
if toolCallIndex == 0 {
|
||||||
res.Message.Content = sb.String()
|
res.Message.Content = sb.String()
|
||||||
}
|
}
|
||||||
ch <- res
|
ch <- res
|
||||||
|
|
Loading…
Reference in a new issue