Support image input for OpenAI chat compatibility (#5208)
* OpenAI v1 models * Refactor Writers * Add Test Co-Authored-By: Attila Kerekes * Credit Co-Author Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com> * Empty List Testing * Use Namespace for Ownedby * Update Test * Add back envconfig * v1/models docs * Use ModelName Parser * Test Names * Remove Docs * Clean Up * Test name Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com> * Add Middleware for Chat and List * Testing Cleanup * Test with Fatal * Add functionality to chat test * Support image input for OpenAI chat * Decoding * Fix message processing logic * openai vision test * type errors * clean up * redundant check * merge conflicts * merge conflicts * merge conflicts * flattening and smaller image * add test * support python and js SDKs and mandate prefixing * clean up --------- Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com> Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
This commit is contained in:
parent
057d31861e
commit
e9f7f36029
2 changed files with 119 additions and 6 deletions
|
@ -3,11 +3,13 @@ package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
@ -28,7 +30,7 @@ type ErrorResponse struct {
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content any `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
|
@ -269,10 +271,66 @@ func toModel(r api.ShowResponse, m string) Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
for _, msg := range r.Messages {
|
for _, msg := range r.Messages {
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
switch content := msg.Content.(type) {
|
||||||
|
case string:
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||||
|
case []any:
|
||||||
|
message := api.Message{Role: msg.Role}
|
||||||
|
for _, c := range content {
|
||||||
|
data, ok := c.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
switch data["type"] {
|
||||||
|
case "text":
|
||||||
|
text, ok := data["text"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
message.Content = text
|
||||||
|
case "image_url":
|
||||||
|
var url string
|
||||||
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
|
if url, ok = urlMap["url"].(string); !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if url, ok = data["image_url"].(string); !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
types := []string{"jpeg", "jpg", "png"}
|
||||||
|
valid := false
|
||||||
|
for _, t := range types {
|
||||||
|
prefix := "data:image/" + t + ";base64,"
|
||||||
|
if strings.HasPrefix(url, prefix) {
|
||||||
|
url = strings.TrimPrefix(url, prefix)
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
return nil, fmt.Errorf("invalid image input")
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := base64.StdEncoding.DecodeString(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
message.Images = append(message.Images, img)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, message)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
|
@ -323,13 +381,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||||
format = "json"
|
format = "json"
|
||||||
}
|
}
|
||||||
|
|
||||||
return api.ChatRequest{
|
return &api.ChatRequest{
|
||||||
Model: r.Model,
|
Model: r.Model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Format: format,
|
Format: format,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
|
@ -656,7 +714,13 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
|
||||||
|
chatReq, err := fromChatRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -15,6 +16,10 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const prefix = `data:image/jpeg;base64,`
|
||||||
|
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
|
const imageURL = prefix + image
|
||||||
|
|
||||||
func TestMiddlewareRequests(t *testing.T) {
|
func TestMiddlewareRequests(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
Name string
|
||||||
|
@ -112,6 +117,50 @@ func TestMiddlewareRequests(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler with image content",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/chat",
|
||||||
|
Handler: ChatMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{
|
||||||
|
{
|
||||||
|
Role: "user", Content: []map[string]any{
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var chatReq api.ChatRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatReq.Messages[0].Role != "user" {
|
||||||
|
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatReq.Messages[0].Content != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
|
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
|
||||||
|
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
Loading…
Reference in a new issue