server: parallelize embeddings in API web handler instead of in subprocess runner (#6220)
For simplicity, parallelize embedding requests in the API handler instead of offloading that work to the subprocess runner. This keeps the scheduling story simpler: it builds on the existing parallel request handling, the same way text completion requests already work.
parent 25906d72d1
commit 15c2d8fe14
4 changed files with 53 additions and 71 deletions
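At a glance, the change moves the fan-out from the vendored C++ server into the Go API handler: the handler issues one Embedding call per input and runs the calls concurrently, while each call still consumes a single runner slot. Below is a minimal, self-contained sketch of that pattern (the errgroup usage mirrors the EmbedHandler diff further down; embedOne and the inputs are illustrative stand-ins, not code from this commit):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// embedOne stands in for llmServer.Embedding: one request to the runner per input.
func embedOne(ctx context.Context, text string) ([]float32, error) {
	return []float32{float32(len(text))}, nil
}

func main() {
	ctx := context.Background()
	input := []string{"why is the sky blue?", "how do planes fly?"}

	var g errgroup.Group
	embeddings := make([][]float32, len(input))
	for i, text := range input {
		i, text := i, text // explicit capture; unnecessary on Go 1.22+ but harmless
		g.Go(func() error {
			e, err := embedOne(ctx, text)
			if err != nil {
				return err
			}
			embeddings[i] = e // each goroutine writes only its own index
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		fmt.Println("embedding generation failed:", err)
		return
	}
	fmt.Println(len(embeddings), "embeddings generated")
}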
llm/ext_server/server.cpp (vendored): 42 changed lines
@@ -1223,9 +1223,7 @@ struct llama_server_context
 
             res.result_json = json
             {
-                {"id", res.id},
                 {"embedding", std::vector<float>(embd, embd + n_embd)},
-                {"timings", slot.get_formated_timings()},
             };
         }
     }
@@ -3194,41 +3192,17 @@ int main(int argc, char **argv) {
                 prompt = "";
             }
 
-            if (prompt.size() == 1) {
-                prompt = prompt[0];
-            }
-
             // create and queue the task
-            json responses;
-            {
-                const int id_task = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(id_task);
-                llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
-
-                // get the result
-                task_result result = llama.queue_results.recv(id_task);
-                llama.queue_results.remove_waiting_task_id(id_task);
-                if (result.error) {
-                    return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
-                }
-
-                responses = result.result_json.value("results", std::vector<json>{result.result_json});
-                std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
-                    return a["id"] < b["id"];
-                });
-
-                json embeddings = json::array();
-
-                int prompt_n = 0;
-                for (auto & elem : responses) {
-                    embeddings.push_back(elem.at("embedding"));
-                    prompt_n += elem.at("timings").at("prompt_n").get<int>();
-                }
-
-                // send the result
-                json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
-                return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
-            }
+            const int task_id = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
+
+            // get the result
+            task_result result = llama.queue_results.recv(task_id);
+            llama.queue_results.remove_waiting_task_id(task_id);
+
+            // send the result
+            return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
         });
 
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
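With the batching logic removed, the vendored server's /embedding route behaves like upstream llama.cpp again: one prompt in, one vector out, and parallelism now comes from the Go layer issuing multiple concurrent HTTP requests. A hedged sketch of the wire contract the Go client relies on (field names are taken from the EmbeddingRequest and EmbeddingResponse structs changed below; the port is a placeholder):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Single-prompt request body, matching the "content" field of EmbeddingRequest.
	body, _ := json.Marshal(map[string]string{"content": "why is the sky blue?"})

	// The real client targets the runner's dynamically assigned port; 8080 is illustrative.
	resp, err := http.Post("http://127.0.0.1:8080/embedding", "application/json", bytes.NewBuffer(body))
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	// Single-vector response, matching the "embedding" field of EmbeddingResponse.
	var out struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	fmt.Println("got", len(out.Embedding), "dimensions")
}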
@@ -33,7 +33,7 @@ type LlamaServer interface {
     Ping(ctx context.Context) error
     WaitUntilRunning(ctx context.Context) error
     Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-    Embed(ctx context.Context, input []string) (*EmbedResponse, error)
+    Embedding(ctx context.Context, input string) ([]float32, error)
     Tokenize(ctx context.Context, content string) ([]int, error)
     Detokenize(ctx context.Context, tokens []int) (string, error)
     Close() error
@@ -883,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
     return nil
 }
 
-type EmbedRequest struct {
-    Content []string `json:"content"`
+type EmbeddingRequest struct {
+    Content string `json:"content"`
 }
 
-type EmbedResponse struct {
-    Embedding       [][]float32 `json:"embedding"`
-    PromptEvalCount int         `json:"prompt_n"`
+type EmbeddingResponse struct {
+    Embedding []float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
-    // each input will use a slot, so we need to acquire the semaphore for
-    // the number of inputs up to numParallel
-    slots := int64(min(len(input), s.numParallel))
-    if err := s.sem.Acquire(ctx, slots); err != nil {
+func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
+    if err := s.sem.Acquire(ctx, 1); err != nil {
         slog.Error("Failed to acquire semaphore", "error", err)
         return nil, err
     }
-    defer s.sem.Release(slots)
+    defer s.sem.Release(1)
 
     // Make sure the server is ready
     status, err := s.getServerStatusRetry(ctx)
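Note the concurrency cap: the old Embed acquired up to numParallel slots for a whole batch, while Embedding now takes exactly one slot per call, so the handler-side fan-out is throttled by the same weighted semaphore that gates completions. A minimal illustration of that throttling, assuming the field is a golang.org/x/sync/semaphore.Weighted sized to numParallel (the constants and timings here are made up):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	const numParallel = 2
	sem := semaphore.NewWeighted(numParallel) // stand-in for llmServer.sem

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// Each Embedding call acquires a single slot, as in the diff above;
			// callers beyond numParallel block here until a slot frees up.
			if err := sem.Acquire(context.Background(), 1); err != nil {
				return
			}
			defer sem.Release(1)
			fmt.Println("embedding request", id, "holding a slot")
			time.Sleep(50 * time.Millisecond) // pretend to do work
		}(i)
	}
	wg.Wait()
}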
@@ -910,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
         return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
     }
 
-    data, err := json.Marshal(EmbedRequest{Content: input})
+    data, err := json.Marshal(EmbeddingRequest{Content: input})
     if err != nil {
         return nil, fmt.Errorf("error marshaling embed data: %w", err)
     }
 
-    req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+    r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
     if err != nil {
         return nil, fmt.Errorf("error creating embed request: %w", err)
     }
-    req.Header.Set("Content-Type", "application/json")
+    r.Header.Set("Content-Type", "application/json")
 
-    resp, err := http.DefaultClient.Do(req)
+    resp, err := http.DefaultClient.Do(r)
     if err != nil {
         return nil, fmt.Errorf("do embedding request: %w", err)
     }
@@ -937,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
         return nil, fmt.Errorf("%s", body)
     }
 
-    var e EmbedResponse
+    var e EmbeddingResponse
     if err := json.Unmarshal(body, &e); err != nil {
         return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
     }
 
-    return &e, nil
+    return e.Embedding, nil
 }
 
 type TokenizeRequest struct {
@@ -23,6 +23,7 @@ import (
 
     "github.com/gin-contrib/cors"
     "github.com/gin-gonic/gin"
+    "golang.org/x/sync/errgroup"
 
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/envconfig"
@@ -346,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }
 
+    var count int
     for i, s := range input {
         tokens, err := r.Tokenize(c.Request.Context(), s)
         if err != nil {
@@ -368,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
             }
         }
 
+        count += len(tokens)
+
         input[i] = s
     }
-    embeddings, err := r.Embed(c.Request.Context(), input)
-    if err != nil {
-        slog.Error("embedding generation failed", "error", err)
-        c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-        return
+
+    var g errgroup.Group
+    embeddings := make([][]float32, len(input))
+    for i, text := range input {
+        g.Go(func() error {
+            embedding, err := r.Embedding(c.Request.Context(), text)
+            if err != nil {
+                return err
+            }
+            embeddings[i] = normalize(embedding)
+            return nil
+        })
     }
 
-    for i, e := range embeddings.Embedding {
-        embeddings.Embedding[i] = normalize(e)
+    if err := g.Wait(); err != nil {
+        slog.Error("embedding generation failed", "error", err)
+        c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
+        return
     }
 
     resp := api.EmbedResponse{
         Model:           req.Model,
-        Embeddings:      embeddings.Embedding,
+        Embeddings:      embeddings,
         TotalDuration:   time.Since(checkpointStart),
         LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-        PromptEvalCount: embeddings.PromptEvalCount,
+        PromptEvalCount: count,
     }
     c.JSON(http.StatusOK, resp)
 }
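Two details in the handler above are worth spelling out. First, errgroup.Group runs each closure on its own goroutine and Wait returns the first non-nil error, so one failed input fails the whole request. Second, writing embeddings[i] from each goroutine is race-free because every goroutine owns a distinct index. The normalize helper is not part of this diff; the sketch below shows what an L2 normalization would look like, as an assumption about its behavior rather than code from the commit:

package main

import (
	"fmt"
	"math"
)

// normalize is a hypothetical sketch of the helper used by EmbedHandler:
// it scales the vector to unit L2 length. The real implementation lives
// elsewhere in the server package and is not shown in this diff.
func normalize(vec []float32) []float32 {
	var sum float64
	for _, v := range vec {
		sum += float64(v) * float64(v)
	}
	if sum == 0 {
		return vec
	}
	scale := float32(1 / math.Sqrt(sum))
	out := make([]float32, len(vec))
	for i, v := range vec {
		out[i] = v * scale
	}
	return out
}

func main() {
	fmt.Println(normalize([]float32{3, 4})) // [0.6 0.8]
}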
@@ -430,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
         return
     }
 
-    embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+    embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
     if err != nil {
         slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
         c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
         return
     }
 
-    embedding := make([]float64, len(embeddings.Embedding[0]))
-
-    for i, v := range embeddings.Embedding[0] {
-        embedding[i] = float64(v)
+    var e []float64
+    for _, v := range embedding {
+        e = append(e, float64(v))
     }
 
     resp := api.EmbeddingResponse{
-        Embedding: embedding,
+        Embedding: e,
     }
     c.JSON(http.StatusOK, resp)
 }
@@ -708,8 +708,8 @@ type mockLlm struct {
     pingResp         error
     waitResp         error
    completionResp   error
-    embedResp        *llm.EmbedResponse
-    embedRespErr     error
+    embeddingResp    []float32
+    embeddingRespErr error
     tokenizeResp     []int
     tokenizeRespErr  error
     detokenizeResp   string
@@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
     return s.completionResp
 }
 
-func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
-    return s.embedResp, s.embedRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
+    return s.embeddingResp, s.embeddingRespErr
 }
 
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
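With the mock switched to a single vector, scheduler tests can stub embedding results directly. A small hedged sketch of how a test in the same package as mockLlm might use it (imports context and testing; the values and assertions are made up for illustration):

func TestMockEmbedding(t *testing.T) {
	s := &mockLlm{embeddingResp: []float32{0.1, 0.2, 0.3}}

	vec, err := s.Embedding(context.Background(), "hello")
	if err != nil {
		t.Fatal(err)
	}
	if len(vec) != 3 {
		t.Fatalf("expected 3 dimensions, got %d", len(vec))
	}
}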