OpenAI: /v1/embeddings compatibility (#5285)

* OpenAI v1 models

* Empty List Testing

* Add back envconfig

* v1/models docs

* Remove Docs

* OpenAI batch embed compatibility

* merge conflicts

* integrate with api/embed

* ep

* merge conflicts

* request tests

* rm resp test

* merge conflict

* merge conflict

* test fixes

* test fn renaming

* input validation for empty string

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
This commit is contained in:
royjhan 2024-07-16 13:36:08 -07:00 committed by GitHub
parent a8388beb94
commit 987dbab0b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 184 additions and 0 deletions

View file

@ -61,6 +61,11 @@ type ResponseFormat struct {
Type string `json:"type"`
}
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@ -134,11 +139,23 @@ type Model struct {
OwnedBy string `json:"owned_by"`
}
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
}
}
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}
return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model {
return Model{
Id: m,
@ -465,6 +503,11 @@ type RetrieveWriter struct {
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()

View file

@ -161,6 +161,78 @@ func TestMiddlewareRequests(t *testing.T) {
}
},
},
{
Name: "embed handler single input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
if embedReq.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
{
Name: "embed handler batch input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
input, ok := embedReq.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
}
gin.SetMode(gin.TestMode)

View file

@ -1064,6 +1064,7 @@ func (s *Server) GenerateRoutes() http.Handler {
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)