diff --git a/openai/openai.go b/openai/openai.go index 81e4011d..01864e48 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "math/rand" "net/http" "strings" @@ -29,8 +30,9 @@ type ErrorResponse struct { } type Message struct { - Role string `json:"role"` - Content any `json:"content"` + Role string `json:"role"` + Content any `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type Choice struct { @@ -78,6 +80,7 @@ type ChatCompletionRequest struct { PresencePenalty *float64 `json:"presence_penalty_penalty"` TopP *float64 `json:"top_p"` ResponseFormat *ResponseFormat `json:"response_format"` + Tools []api.Tool `json:"tools"` } type ChatCompletion struct { @@ -133,6 +136,15 @@ type CompletionChunk struct { SystemFingerprint string `json:"system_fingerprint"` } +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` +} + type Model struct { Id string `json:"id"` Object string `json:"object"` @@ -171,7 +183,31 @@ func NewError(code int, message string) ErrorResponse { return ErrorResponse{Error{Type: etype, Message: message}} } +func toolCallId() string { + const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, 8) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return "call_" + strings.ToLower(string(b)) +} + func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { + toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) + for i, tc := range r.Message.ToolCalls { + toolCalls[i].ID = toolCallId() + toolCalls[i].Type = "function" + toolCalls[i].Function.Name = tc.Function.Name + + args, err := json.Marshal(tc.Function.Arguments) + if err != nil { + slog.Error("could not marshall function arguments to json", "error", err) + continue + } + + toolCalls[i].Function.Arguments = string(args) + } + return ChatCompletion{ Id: id, Object: "chat.completion", @@ -180,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { SystemFingerprint: "fp_ollama", Choices: []Choice{{ Index: 0, - Message: Message{Role: r.Message.Role, Content: r.Message.Content}, + Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls}, FinishReason: func(reason string) *string { if len(reason) > 0 { return &reason @@ -366,7 +402,19 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } messages = append(messages, message) default: - return nil, fmt.Errorf("invalid message content type: %T", content) + if msg.ToolCalls == nil { + return nil, fmt.Errorf("invalid message content type: %T", content) + } + + toolCalls := make([]api.ToolCall, len(msg.ToolCalls)) + for i, tc := range msg.ToolCalls { + toolCalls[i].Function.Name = tc.Function.Name + err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments) + if err != nil { + return nil, fmt.Errorf("invalid tool call arguments") + } + } + messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls}) } } @@ -424,6 +472,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { Format: format, Options: options, Stream: &r.Stream, + Tools: r.Tools, }, nil }