diff --git a/openai/openai.go b/openai/openai.go index b289d73e..88bdaec1 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -61,6 +61,11 @@ type ResponseFormat struct { Type string `json:"type"` } +type EmbedRequest struct { + Input any `json:"input"` + Model string `json:"model"` +} + type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -134,11 +139,23 @@ type Model struct { OwnedBy string `json:"owned_by"` } +type Embedding struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` +} + type ListCompletion struct { Object string `json:"object"` Data []Model `json:"data"` } +type EmbeddingList struct { + Object string `json:"object"` + Data []Embedding `json:"data"` + Model string `json:"model"` +} + func NewError(code int, message string) ErrorResponse { var etype string switch code { @@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion { } } +func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { + if r.Embeddings != nil { + var data []Embedding + for i, e := range r.Embeddings { + data = append(data, Embedding{ + Object: "embedding", + Embedding: e, + Index: i, + }) + } + + return EmbeddingList{ + Object: "list", + Data: data, + Model: model, + } + } + + return EmbeddingList{} +} + func toModel(r api.ShowResponse, m string) Model { return Model{ Id: m, @@ -465,6 +503,11 @@ type RetrieveWriter struct { model string } +type EmbedWriter struct { + BaseWriter + model string +} + func (w *BaseWriter) writeError(code int, data []byte) (int, error) { var serr api.StatusError err := json.Unmarshal(data, &serr) @@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) { return w.writeResponse(data) } +func (w *EmbedWriter) writeResponse(data []byte) (int, error) { + var embedResponse api.EmbedResponse + err := json.Unmarshal(data, &embedResponse) + + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) + + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *EmbedWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + func ListMiddleware() gin.HandlerFunc { return func(c *gin.Context) { w := &ListWriter{ @@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc { id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), } + c.Writer = w + c.Next() + } +} + +func EmbeddingsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req EmbedRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Input == "" { + req.Input = []string{""} + } + + if req.Input == nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) + return + } + + if v, ok := req.Input.([]any); ok && len(v) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &EmbedWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: req.Model, + } + c.Writer = w c.Next() diff --git a/openai/openai_test.go b/openai/openai_test.go index 99f8baaf..5fc22b88 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -161,6 +161,78 @@ func TestMiddlewareRequests(t *testing.T) { } }, }, + { + Name: "embed handler single input", + Method: http.MethodPost, + Path: "/api/embed", + Handler: EmbeddingsMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Input: "Hello", + Model: "test-model", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var embedReq api.EmbedRequest + if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { + t.Fatal(err) + } + + if embedReq.Input != "Hello" { + t.Fatalf("expected 'Hello', got %s", embedReq.Input) + } + + if embedReq.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", embedReq.Model) + } + }, + }, + { + Name: "embed handler batch input", + Method: http.MethodPost, + Path: "/api/embed", + Handler: EmbeddingsMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Input: []string{"Hello", "World"}, + Model: "test-model", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var embedReq api.EmbedRequest + if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { + t.Fatal(err) + } + + input, ok := embedReq.Input.([]any) + + if !ok { + t.Fatalf("expected input to be a list") + } + + if input[0].(string) != "Hello" { + t.Fatalf("expected 'Hello', got %s", input[0]) + } + + if input[1].(string) != "World" { + t.Fatalf("expected 'World', got %s", input[1]) + } + + if embedReq.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", embedReq.Model) + } + }, + }, } gin.SetMode(gin.TestMode) diff --git a/server/routes.go b/server/routes.go index d0cbe6cc..d22a099a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,6 +1064,7 @@ func (s *Server) GenerateRoutes() http.Handler { // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) + r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)