adjust openai chat msg processing (#5729)
This commit is contained in:
parent
d0634b1596
commit
51b2fd299c
2 changed files with 9 additions and 6 deletions
|
@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
case string:
|
case string:
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||||
case []any:
|
case []any:
|
||||||
message := api.Message{Role: msg.Role}
|
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
data, ok := c.(map[string]any)
|
data, ok := c.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
message.Content = text
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
||||||
case "image_url":
|
case "image_url":
|
||||||
var url string
|
var url string
|
||||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
|
@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
message.Images = append(message.Images, img)
|
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messages = append(messages, message)
|
|
||||||
default:
|
default:
|
||||||
if msg.ToolCalls == nil {
|
if msg.ToolCalls == nil {
|
||||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
|
|
|
@ -161,8 +161,12 @@ func TestMiddlewareRequests(t *testing.T) {
|
||||||
|
|
||||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
|
if chatReq.Messages[1].Role != "user" {
|
||||||
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
|
t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
|
||||||
|
t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in a new issue