adjust openai chat msg processing (#5729)

This commit is contained in:
royjhan 2024-07-19 11:19:20 -07:00 committed by GitHub
parent d0634b1596
commit 51b2fd299c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 6 deletions

View file

@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
case string: case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content}) messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any: case []any:
message := api.Message{Role: msg.Role}
for _, c := range content { for _, c := range content {
data, ok := c.(map[string]any) data, ok := c.(map[string]any)
if !ok { if !ok {
@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if !ok { if !ok {
return nil, fmt.Errorf("invalid message format") return nil, fmt.Errorf("invalid message format")
} }
message.Content = text messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url": case "image_url":
var url string var url string
if urlMap, ok := data["image_url"].(map[string]any); ok { if urlMap, ok := data["image_url"].(map[string]any); ok {
@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid message format") return nil, fmt.Errorf("invalid message format")
} }
message.Images = append(message.Images, img)
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default: default:
return nil, fmt.Errorf("invalid message format") return nil, fmt.Errorf("invalid message format")
} }
} }
messages = append(messages, message)
default: default:
if msg.ToolCalls == nil { if msg.ToolCalls == nil {
return nil, fmt.Errorf("invalid message content type: %T", content) return nil, fmt.Errorf("invalid message content type: %T", content)

View file

@ -161,8 +161,12 @@ func TestMiddlewareRequests(t *testing.T) {
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if !bytes.Equal(chatReq.Messages[0].Images[0], img) { if chatReq.Messages[1].Role != "user" {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0]) t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
}
if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
} }
}, },
}, },