OpenAI: Support Tools (#5614)

* reopen pr

* tools

* remove tc from stream for now

* ID and Function

* openai expects arguments to be a string (#5739)

* mutually exclusive content and tool calls

* clean up

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
This commit is contained in:
royjhan 2024-07-16 20:52:59 -07:00 committed by GitHub
parent 0d41623b52
commit 154f6f45d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
@ -29,8 +30,9 @@ type ErrorResponse struct {
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content any `json:"content"` Content any `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} }
type Choice struct { type Choice struct {
@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
PresencePenalty *float64 `json:"presence_penalty_penalty"` PresencePenalty *float64 `json:"presence_penalty_penalty"`
TopP *float64 `json:"top_p"` TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"` ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
} }
type ChatCompletion struct { type ChatCompletion struct {
@ -133,6 +136,15 @@ type CompletionChunk struct {
SystemFingerprint string `json:"system_fingerprint"` SystemFingerprint string `json:"system_fingerprint"`
} }
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
type Model struct { type Model struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
@ -171,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}} return ErrorResponse{Error{Type: etype, Message: message}}
} }
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return "call_" + strings.ToLower(string(b))
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
for i, tc := range r.Message.ToolCalls {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
args, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("could not marshall function arguments to json", "error", err)
continue
}
toolCalls[i].Function.Arguments = string(args)
}
return ChatCompletion{ return ChatCompletion{
Id: id, Id: id,
Object: "chat.completion", Object: "chat.completion",
@ -180,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
SystemFingerprint: "fp_ollama", SystemFingerprint: "fp_ollama",
Choices: []Choice{{ Choices: []Choice{{
Index: 0, Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content}, Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason
@ -366,7 +402,19 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
messages = append(messages, message) messages = append(messages, message)
default: default:
return nil, fmt.Errorf("invalid message content type: %T", content) if msg.ToolCalls == nil {
return nil, fmt.Errorf("invalid message content type: %T", content)
}
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
toolCalls[i].Function.Name = tc.Function.Name
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
if err != nil {
return nil, fmt.Errorf("invalid tool call arguments")
}
}
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
} }
} }
@ -424,6 +472,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Tools: r.Tools,
}, nil }, nil
} }