From b9f5e16c8025f115abde34ff047023f4d6e34af5 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:14:24 -0700 Subject: [PATCH] Introduce `/api/embed` endpoint supporting batch embedding (#5127) * Initial Batch Embedding * Revert "Initial Batch Embedding" This reverts commit c22d54895a280b54c727279d85a5fc94defb5a29. * Initial Draft * mock up notes * api/embed draft * add server function * check normalization * clean up * normalization * playing around with truncate stuff * Truncation * Truncation * move normalization to go * Integration Test Template * Truncation Integration Tests * Clean up * use float32 * move normalize * move normalize test * refactoring * integration float32 * input handling and handler testing * Refactoring of legacy and new * clear comments * merge conflicts * touches * embedding type 64 * merge conflicts * fix hanging on single string * refactoring * test values * set context length * clean up * testing clean up * testing clean up * remove function closure * Revert "remove function closure" This reverts commit 55d48c6ed17abe42e7a122e69d603ef0c1506787. * remove function closure * remove redundant error check * clean up * more clean up * clean up --- api/client.go | 11 ++- api/types.go | 24 ++++++ integration/embed_test.go | 152 ++++++++++++++++++++++++++++++++++++++ llm/ext_server/server.cpp | 37 ++++++---- llm/server.go | 16 ++-- server/routes.go | 131 +++++++++++++++++++++++++++++++- server/routes_test.go | 103 ++++++++++++++++++++++++++ server/sched_test.go | 8 +- 8 files changed, 452 insertions(+), 30 deletions(-) create mode 100644 integration/embed_test.go diff --git a/api/client.go b/api/client.go index fccbc9ad..c59fbc42 100644 --- a/api/client.go +++ b/api/client.go @@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error { return nil } -// Embeddings generates embeddings from a model. +// Embed generates embeddings from a model. +func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + var resp EmbedResponse + if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// Embeddings generates an embedding from a model. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { var resp EmbeddingResponse if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { diff --git a/api/types.go b/api/types.go index 91c97c71..bf552928 100644 --- a/api/types.go +++ b/api/types.go @@ -173,6 +173,30 @@ type Runner struct { NumThread int `json:"num_thread,omitempty"` } +// EmbedRequest is the request passed to [Client.Embed]. +type EmbedRequest struct { + // Model is the model name. + Model string `json:"model"` + + // Input is the input to embed. + Input any `json:"input"` + + // KeepAlive controls how long the model will stay loaded in memory following + // this request. + KeepAlive *Duration `json:"keep_alive,omitempty"` + + Truncate *bool `json:"truncate,omitempty"` + + // Options lists model-specific options. + Options map[string]interface{} `json:"options"` +} + +// EmbedResponse is the response from [Client.Embed]. +type EmbedResponse struct { + Model string `json:"model"` + Embeddings [][]float32 `json:"embeddings,omitempty"` +} + // EmbeddingRequest is the request passed to [Client.Embeddings]. type EmbeddingRequest struct { // Model is the model name. diff --git a/integration/embed_test.go b/integration/embed_test.go new file mode 100644 index 00000000..aeafa57b --- /dev/null +++ b/integration/embed_test.go @@ -0,0 +1,152 @@ +//go:build integration + +package integration + +import ( + "context" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func TestAllMiniLMEmbed(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + } + + res, err := embedTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } + + if len(res.Embeddings) != 1 { + t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) + } + + if res.Embeddings[0][0] != 0.010071031 { + t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0]) + } +} + +func TestAllMiniLMBatchEmbed(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"why is the sky blue?", "why is the grass green?"}, + } + + res, err := embedTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } + + if len(res.Embeddings) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) + } + + if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 { + t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0]) + } +} + +func TestAllMiniLmEmbedTruncate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + truncTrue, truncFalse := true, false + + type testReq struct { + Name string + Request api.EmbedRequest + } + + reqs := []testReq{ + { + Name: "Target Truncation", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why", + }, + }, + { + Name: "Default Truncate", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Options: map[string]any{"num_ctx": 1}, + }, + }, + { + Name: "Explicit Truncate", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 1}, + }, + }, + } + + res := make(map[string]*api.EmbedResponse) + + for _, req := range reqs { + response, err := embedTestHelper(ctx, t, req.Request) + if err != nil { + t.Fatalf("error: %v", err) + } + res[req.Name] = response + } + + if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { + t.Fatal("expected default request to truncate correctly") + } + + if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { + t.Fatal("expected default request and truncate true request to be the same") + } + + // check that truncate set to false returns an error if context length is exceeded + _, err := embedTestHelper(ctx, t, api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 1}, + }) + + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatalf("failed to pull model %s: %v", req.Model, err) + } + + response, err := client.Embed(ctx, &req) + + if err != nil { + return nil, err + } + + return response, nil +} diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 0ef3956e..e8a076c4 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -3188,26 +3188,33 @@ int main(int argc, char **argv) { prompt = ""; } - json image_data; - if (body.count("image_data") != 0) { - image_data = body["image_data"]; - } - else - { - image_data = ""; + if (prompt.size() == 1) { + prompt = prompt[0]; } // create and queue the task - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1); + json responses; + { + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + llama.request_completion(id_task, {{"prompt", prompt}}, true, -1); - // get the result - task_result result = llama.queue_results.recv(task_id); - llama.queue_results.remove_waiting_task_id(task_id); + // get the result + task_result result = llama.queue_results.recv(id_task); + llama.queue_results.remove_waiting_task_id(id_task); + if (result.error) { + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); + } - // send the result - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); + responses = result.result_json.value("results", std::vector{result.result_json}); + json embeddings = json::array(); + for (auto & elem : responses) { + embeddings.push_back(elem.at("embedding")); + } + // send the result + json embedding_res = json{{"embedding", embeddings}}; + return res.set_content(embedding_res.dump(), "application/json; charset=utf-8"); + } }); // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? diff --git a/llm/server.go b/llm/server.go index ffed9fc0..36c0e0b5 100644 --- a/llm/server.go +++ b/llm/server.go @@ -33,7 +33,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embedding(ctx context.Context, prompt string) ([]float64, error) + Embed(ctx context.Context, input []string) ([][]float32, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return nil } -type EmbeddingRequest struct { - Content string `json:"content"` +type EmbedRequest struct { + Content []string `json:"content"` } -type EmbeddingResponse struct { - Embedding []float64 `json:"embedding"` +type EmbedResponse struct { + Embedding [][]float32 `json:"embedding"` } -func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { +func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) { if err := s.sem.Acquire(ctx, 1); err != nil { slog.Error("Failed to acquire semaphore", "error", err) return nil, err @@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } - data, err := json.Marshal(TokenizeRequest{Content: prompt}) + data, err := json.Marshal(EmbedRequest{Content: input}) if err != nil { return nil, fmt.Errorf("error marshaling embed data: %w", err) } @@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er return nil, fmt.Errorf("%s", body) } - var embedding EmbeddingResponse + var embedding EmbedResponse if err := json.Unmarshal(body, &embedding); err != nil { return nil, fmt.Errorf("unmarshal tokenize response: %w", err) } diff --git a/server/routes.go b/server/routes.go index 0a00d9e2..c5c3a19c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log/slog" + "math" "net" "net/http" "net/netip" @@ -271,6 +272,121 @@ func (s *Server) GenerateHandler(c *gin.Context) { streamResponse(c, ch) } +func (s *Server) EmbedHandler(c *gin.Context) { + var req api.EmbedRequest + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + truncate := true + + if req.Truncate != nil && !*req.Truncate { + truncate = false + } + + var input []string + + switch i := req.Input.(type) { + case string: + if len(i) > 0 { + input = append(input, i) + } + case []any: + for _, v := range i { + if _, ok := v.(string); !ok { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + input = append(input, v.(string)) + } + default: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + + if len(input) == 0 { + c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) + return + } + + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + if err != nil { + handleScheduleError(c, req.Model, err) + return + } + + kvData, err := getKVData(m.ModelPath, false) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + for i, s := range input { + tokens, err := r.Tokenize(c.Request.Context(), s) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) + if len(tokens) > ctxLen { + if !truncate { + c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) + return + } + + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + input[i] = s + } + embeddings, err := r.Embed(c.Request.Context(), input) + + if err != nil { + slog.Error("embedding generation failed", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) + return + } + + for i, e := range embeddings { + embeddings[i] = normalize(e) + } + + resp := api.EmbedResponse{ + Model: req.Model, + Embeddings: embeddings, + } + c.JSON(http.StatusOK, resp) +} + +func normalize(vec []float32) []float32 { + var sum float32 + for _, v := range vec { + sum += v * v + } + + norm := float32(0.0) + if sum > 0 { + norm = float32(1.0 / math.Sqrt(float64(sum))) + } + + for i := range vec { + vec[i] *= norm + } + return vec +} + func (s *Server) EmbeddingsHandler(c *gin.Context) { var req api.EmbeddingRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { @@ -293,14 +409,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.Embedding(c.Request.Context(), req.Prompt) + embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}) + if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) return } - c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding}) + embedding := make([]float64, len(embeddings[0])) + + for i, v := range embeddings[0] { + embedding[i] = float64(v) + } + + resp := api.EmbeddingResponse{ + Embedding: embedding, + } + c.JSON(http.StatusOK, resp) } func (s *Server) PullModelHandler(c *gin.Context) { @@ -919,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/pull", s.PullModelHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) + r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/create", s.CreateModelHandler) r.POST("/api/push", s.PushModelHandler) diff --git a/server/routes_test.go b/server/routes_test.go index 50eaf7e9..70622e9b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" "net/http/httptest" "os" @@ -272,6 +273,73 @@ func Test_Routes(t *testing.T) { assert.Equal(t, "library", retrieveResp.OwnedBy) }, }, + { + Name: "Embed Handler Empty Input", + Method: http.MethodPost, + Path: "/api/embed", + Setup: func(t *testing.T, req *http.Request) { + embedReq := api.EmbedRequest{ + Model: "t-bone", + Input: "", + } + jsonData, err := json.Marshal(embedReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + if contentType != "application/json; charset=utf-8" { + t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + var embedResp api.EmbedResponse + err = json.Unmarshal(body, &embedResp) + if err != nil { + t.Fatal(err) + } + + if embedResp.Model != "t-bone" { + t.Fatalf("expected model t-bone, got %s", embedResp.Model) + } + + if embedResp.Embeddings != nil { + t.Fatalf("expected embeddings to be nil, got %v", embedResp.Embeddings) + } + }, + }, + { + Name: "Embed Handler Invalid Input", + Method: http.MethodPost, + Path: "/api/embed", + Setup: func(t *testing.T, req *http.Request) { + embedReq := api.EmbedRequest{ + Model: "t-bone", + Input: 2, + } + jsonData, err := json.Marshal(embedReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + if contentType != "application/json; charset=utf-8" { + t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType) + } + _, err := io.ReadAll(resp.Body) + + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status code 400, got %d", resp.StatusCode) + } + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir()) @@ -420,3 +488,38 @@ func TestShow(t *testing.T) { t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"]) } } + +func TestNormalize(t *testing.T) { + type testCase struct { + input []float32 + } + + testCases := []testCase{ + {input: []float32{1}}, + {input: []float32{0, 1, 2, 3}}, + {input: []float32{0.1, 0.2, 0.3}}, + {input: []float32{-0.1, 0.2, 0.3, -0.4}}, + {input: []float32{0, 0, 0}}, + } + + isNormalized := func(vec []float32) (res bool) { + sum := 0.0 + for _, v := range vec { + sum += float64(v * v) + } + if math.Abs(sum-1) > 1e-6 { + return sum == 0 + } else { + return true + } + } + + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + normalized := normalize(tc.input) + if !isNormalized(normalized) { + t.Errorf("Vector %v is not normalized", tc.input) + } + }) + } +} diff --git a/server/sched_test.go b/server/sched_test.go index 3fbd188a..4b000331 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -642,8 +642,8 @@ type mockLlm struct { pingResp error waitResp error completionResp error - embeddingResp []float64 - embeddingRespErr error + embedResp [][]float32 + embedRespErr error tokenizeResp []int tokenizeRespErr error detokenizeResp string @@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { return s.completionResp } -func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { - return s.embeddingResp, s.embeddingRespErr +func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) { + return s.embedResp, s.embedRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { return s.tokenizeResp, s.tokenizeRespErr