OpenAI: /v1/embeddings compatibility (#5285)
* OpenAI v1 models
* Empty List Testing
* Add back envconfig
* v1/models docs
* Remove Docs
* OpenAI batch embed compatibility
* merge conflicts
* integrate with api/embed
* ep
* merge conflicts
* request tests
* rm resp test
* merge conflict
* merge conflict
* test fixes
* test fn renaming
* input validation for empty string

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
parent a8388beb94
commit 987dbab0b0

3 changed files with 184 additions and 0 deletions
openai/openai.go (111 additions)
@@ -61,6 +61,11 @@ type ResponseFormat struct {
 	Type string `json:"type"`
 }
 
+type EmbedRequest struct {
+	Input any    `json:"input"`
+	Model string `json:"model"`
+}
+
 type ChatCompletionRequest struct {
 	Model    string    `json:"model"`
 	Messages []Message `json:"messages"`
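Because Input is declared as `any`, the same request type binds either a single string or an array of strings, mirroring OpenAI's /v1/embeddings contract. A minimal standalone sketch of both shapes (the package layout, model name, and payloads below are illustrative assumptions, not part of the diff):

package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the EmbedRequest shape above, for illustration only.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}

func main() {
	// A single string decodes Input as a Go string.
	single := []byte(`{"model":"all-minilm","input":"why is the sky blue?"}`)
	// An array decodes Input as []any, which is forwarded as a batch.
	batch := []byte(`{"model":"all-minilm","input":["why is the sky blue?","how are rainbows formed?"]}`)

	for _, raw := range [][]byte{single, batch} {
		var req EmbedRequest
		if err := json.Unmarshal(raw, &req); err != nil {
			panic(err)
		}
		fmt.Printf("model=%q input=%#v\n", req.Model, req.Input)
	}
}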
@@ -134,11 +139,23 @@ type Model struct {
 	OwnedBy string `json:"owned_by"`
 }
 
+type Embedding struct {
+	Object    string    `json:"object"`
+	Embedding []float32 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
 type ListCompletion struct {
 	Object string  `json:"object"`
 	Data   []Model `json:"data"`
 }
 
+type EmbeddingList struct {
+	Object string      `json:"object"`
+	Data   []Embedding `json:"data"`
+	Model  string      `json:"model"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
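Embedding and EmbeddingList together form the OpenAI-style wire format that /v1/embeddings returns. A small sketch of the marshalled response (local type copies and made-up vector values, for illustration only):

package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the response types above, for illustration only.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

type EmbeddingList struct {
	Object string      `json:"object"`
	Data   []Embedding `json:"data"`
	Model  string      `json:"model"`
}

func main() {
	list := EmbeddingList{
		Object: "list",
		Data: []Embedding{
			{Object: "embedding", Embedding: []float32{0.1, 0.2, 0.3}, Index: 0},
		},
		Model: "all-minilm", // hypothetical model name
	}

	out, err := json.Marshal(list)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
	// {"object":"list","data":[{"object":"embedding","embedding":[0.1,0.2,0.3],"index":0}],"model":"all-minilm"}
}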
@@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
 	}
 }
 
+func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
+	if r.Embeddings != nil {
+		var data []Embedding
+		for i, e := range r.Embeddings {
+			data = append(data, Embedding{
+				Object:    "embedding",
+				Embedding: e,
+				Index:     i,
+			})
+		}
+
+		return EmbeddingList{
+			Object: "list",
+			Data:   data,
+			Model:  model,
+		}
+	}
+
+	return EmbeddingList{}
+}
+
 func toModel(r api.ShowResponse, m string) Model {
 	return Model{
 		Id: m,
@@ -465,6 +503,11 @@ type RetrieveWriter struct {
 	model string
 }
 
+type EmbedWriter struct {
+	BaseWriter
+	model string
+}
+
 func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
@@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
+func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
+	var embedResponse api.EmbedResponse
+	err := json.Unmarshal(data, &embedResponse)
+
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
+
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *EmbedWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
 func ListMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		w := &ListWriter{
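EmbedWriter follows the same pattern as the other writers in this file: it embeds the caller's gin.ResponseWriter, passes non-200 responses to writeError, and otherwise re-encodes the native /api/embed JSON as an OpenAI EmbeddingList on its way out. A stripped-down, self-contained sketch of that interception pattern (the rewritingWriter type and the /demo route are hypothetical stand-ins, not code from this change):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/gin-gonic/gin"
)

// rewritingWriter embeds gin.ResponseWriter so untouched methods pass
// through, while Write intercepts the body before it reaches the client.
type rewritingWriter struct {
	gin.ResponseWriter
}

func (w *rewritingWriter) Write(data []byte) (int, error) {
	// The real EmbedWriter unmarshals api.EmbedResponse here and re-encodes
	// it as an EmbeddingList; this sketch just wraps the payload.
	if _, err := w.ResponseWriter.Write([]byte(`{"wrapped":` + string(data) + `}`)); err != nil {
		return 0, err
	}
	return len(data), nil
}

func main() {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	// Middleware swaps the writer, exactly as EmbeddingsMiddleware does.
	r.Use(func(c *gin.Context) {
		c.Writer = &rewritingWriter{ResponseWriter: c.Writer}
		c.Next()
	})
	r.GET("/demo", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/demo", nil)
	r.ServeHTTP(rec, req)
	fmt.Println(rec.Body.String()) // {"wrapped":{"ok":true}}
}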
@@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
 			id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
 		}
 
+		c.Writer = w
+		c.Next()
+	}
+}
+
+func EmbeddingsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req EmbedRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Input == "" {
+			req.Input = []string{""}
+		}
+
+		if req.Input == nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		if v, ok := req.Input.([]any); ok && len(v) == 0 {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &EmbedWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      req.Model,
+		}
+
 		c.Writer = w
 
 		c.Next()
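The middleware normalizes and validates Input before rewriting the body for /api/embed: an empty string is promoted to a one-element batch containing the empty string, while a missing input or an empty array is rejected with 400 "invalid input". A standalone restatement of just that decision logic (normalizeInput is a hypothetical helper, for illustration only):

package main

import "fmt"

// normalizeInput mirrors the checks in EmbeddingsMiddleware: it returns the
// (possibly rewritten) input, or ok=false when the middleware would respond
// with 400 "invalid input".
func normalizeInput(input any) (any, bool) {
	if input == "" {
		input = []string{""} // empty string becomes a single-element batch
	}
	if input == nil {
		return nil, false // no input at all
	}
	if v, ok := input.([]any); ok && len(v) == 0 {
		return nil, false // empty batch
	}
	return input, true
}

func main() {
	for _, in := range []any{"hello", "", nil, []any{}, []any{"a", "b"}} {
		out, ok := normalizeInput(in)
		fmt.Printf("%#v -> ok=%v out=%#v\n", in, ok, out)
	}
}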
openai/openai_test.go (72 additions)

@@ -161,6 +161,78 @@ func TestMiddlewareRequests(t *testing.T) {
 				}
 			},
 		},
+		{
+			Name:    "embed handler single input",
+			Method:  http.MethodPost,
+			Path:    "/api/embed",
+			Handler: EmbeddingsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Input: "Hello",
+					Model: "test-model",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var embedReq api.EmbedRequest
+				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if embedReq.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
+				}
+
+				if embedReq.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				}
+			},
+		},
+		{
+			Name:    "embed handler batch input",
+			Method:  http.MethodPost,
+			Path:    "/api/embed",
+			Handler: EmbeddingsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Input: []string{"Hello", "World"},
+					Model: "test-model",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var embedReq api.EmbedRequest
+				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
+					t.Fatal(err)
+				}
+
+				input, ok := embedReq.Input.([]any)
+
+				if !ok {
+					t.Fatalf("expected input to be a list")
+				}
+
+				if input[0].(string) != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", input[0])
+				}
+
+				if input[1].(string) != "World" {
+					t.Fatalf("expected 'World', got %s", input[1])
+				}
+
+				if embedReq.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				}
+			},
+		},
 	}
 
 	gin.SetMode(gin.TestMode)
server/routes.go (1 addition)

@@ -1064,6 +1064,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
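With the route in place, any client that speaks the OpenAI embeddings API can point at the Ollama server. A minimal standard-library sketch, assuming a server listening on http://localhost:11434 and a locally pulled embedding model named "all-minilm" (both are assumptions, not part of this change):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body in the OpenAI /v1/embeddings format; input may be a
	// single string or an array of strings.
	body, _ := json.Marshal(map[string]any{
		"model": "all-minilm", // assumed to be pulled locally
		"input": []string{"why is the sky blue?", "how are rainbows formed?"},
	})

	resp, err := http.Post("http://localhost:11434/v1/embeddings",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode the OpenAI-style list response produced by EmbedWriter.
	var list struct {
		Object string `json:"object"`
		Data   []struct {
			Embedding []float32 `json:"embedding"`
			Index     int       `json:"index"`
		} `json:"data"`
		Model string `json:"model"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		panic(err)
	}
	for _, d := range list.Data {
		fmt.Printf("embedding %d has %d dimensions\n", d.Index, len(d.Embedding))
	}
}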