Add Metrics to api\embed response (#5709)

* add prompt tokens to embed response

* rm slog

* metrics

* types

* prompt n

* clean up

* reset submodule

* update tests

* test name

* list metrics
This commit is contained in:
royjhan 2024-07-30 13:12:21 -07:00 committed by GitHub
parent cef2c6054d
commit 1b44d873e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 39 additions and 15 deletions

View file

@ -267,6 +267,10 @@ type EmbedRequest struct {
type EmbedResponse struct { type EmbedResponse struct {
Model string `json:"model"` Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"` Embeddings [][]float32 `json:"embeddings"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
} }
// EmbeddingRequest is the request passed to [Client.Embeddings]. // EmbeddingRequest is the request passed to [Client.Embeddings].

View file

@ -69,6 +69,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) { if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0]) t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
} }
if res.PromptEvalCount != 8 {
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
}
} }
func TestAllMiniLMBatchEmbed(t *testing.T) { func TestAllMiniLMBatchEmbed(t *testing.T) {
@ -97,6 +101,10 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) { if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0]) t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
} }
if res.PromptEvalCount != 16 {
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
}
} }
func TestAllMiniLMEmbedTruncate(t *testing.T) { func TestAllMiniLMEmbedTruncate(t *testing.T) {

View file

@ -1221,6 +1221,7 @@ struct llama_server_context
res.result_json = json res.result_json = json
{ {
{"embedding", std::vector<float>(embd, embd + n_embd)}, {"embedding", std::vector<float>(embd, embd + n_embd)},
{"timings", slot.get_formated_timings()},
}; };
} }
} }
@ -3203,11 +3204,15 @@ int main(int argc, char **argv) {
responses = result.result_json.value("results", std::vector<json>{result.result_json}); responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array(); json embeddings = json::array();
int prompt_n = 0;
for (auto & elem : responses) { for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding")); embeddings.push_back(elem.at("embedding"));
prompt_n += elem.at("timings").at("prompt_n").get<int>();
} }
// send the result // send the result
json embedding_res = json{{"embedding", embeddings}}; json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8"); return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
} }
}); });

View file

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) ([][]float32, error) Embed(ctx context.Context, input []string) (*EmbedResponse, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@ -879,10 +879,11 @@ type EmbedRequest struct {
} }
type EmbedResponse struct { type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"` Embedding [][]float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_n"`
} }
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) { func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
@ -924,12 +925,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("%s", body) return nil, fmt.Errorf("%s", body)
} }
var embedding EmbedResponse var e EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil { if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }
return embedding.Embedding, nil return &e, nil
} }
type TokenizeRequest struct { type TokenizeRequest struct {

View file

@ -284,6 +284,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
func (s *Server) EmbedHandler(c *gin.Context) { func (s *Server) EmbedHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.EmbedRequest var req api.EmbedRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -332,6 +333,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
kvData, err := getKVData(m.ModelPath, false) kvData, err := getKVData(m.ModelPath, false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -370,13 +373,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
for i, e := range embeddings { for i, e := range embeddings.Embedding {
embeddings[i] = normalize(e) embeddings.Embedding[i] = normalize(e)
} }
resp := api.EmbedResponse{ resp := api.EmbedResponse{
Model: req.Model, Model: req.Model,
Embeddings: embeddings, Embeddings: embeddings.Embedding,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: embeddings.PromptEvalCount,
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@ -428,9 +434,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding := make([]float64, len(embeddings[0])) embedding := make([]float64, len(embeddings.Embedding[0]))
for i, v := range embeddings[0] { for i, v := range embeddings.Embedding[0] {
embedding[i] = float64(v) embedding[i] = float64(v)
} }

View file

@ -709,7 +709,7 @@ type mockLlm struct {
pingResp error pingResp error
waitResp error waitResp error
completionResp error completionResp error
embedResp [][]float32 embedResp *llm.EmbedResponse
embedRespErr error embedRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
@ -727,7 +727,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) { func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
return s.embedResp, s.embedRespErr return s.embedResp, s.embedRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {