Support image input for OpenAI chat compatibility (#5208)

* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* Support image input for OpenAI chat

* Decoding

* Fix message processing logic

* openai vision test

* type errors

* clean up

* redundant check

* merge conflicts

* merge conflicts

* merge conflicts

* flattening and smaller image

* add test

* support python and js SDKs and mandate prefixing

* clean up

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
This commit is contained in:
royjhan 2024-07-13 22:07:45 -07:00 committed by GitHub
parent 057d31861e
commit e9f7f36029
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 119 additions and 6 deletions

View file

@ -3,11 +3,13 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -28,7 +30,7 @@ type ErrorResponse struct {
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content any `json:"content"`
} }
type Choice struct { type Choice struct {
@ -269,10 +271,66 @@ func toModel(r api.ShowResponse, m string) Model {
} }
} }
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message var messages []api.Message
for _, msg := range r.Messages { for _, msg := range r.Messages {
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content}) switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
message := api.Message{Role: msg.Role}
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
message.Content = text
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, fmt.Errorf("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
message.Images = append(message.Images, img)
default:
return nil, fmt.Errorf("invalid message format")
}
}
messages = append(messages, message)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
} }
options := make(map[string]interface{}) options := make(map[string]interface{})
@ -323,13 +381,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
format = "json" format = "json"
} }
return api.ChatRequest{ return &api.ChatRequest{
Model: r.Model, Model: r.Model,
Messages: messages, Messages: messages,
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
} }, nil
} }
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@ -656,7 +714,13 @@ func ChatMiddleware() gin.HandlerFunc {
} }
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return return
} }

View file

@ -2,6 +2,7 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -15,6 +16,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func TestMiddlewareRequests(t *testing.T) { func TestMiddlewareRequests(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
@ -112,6 +117,50 @@ func TestMiddlewareRequests(t *testing.T) {
} }
}, },
}, },
{
Name: "chat handler with image content",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
}
},
},
} }
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)