Add Metrics to api\embed
response (#5709)
* add prompt tokens to embed response * rm slog * metrics * types * prompt n * clean up * reset submodule * update tests * test name * list metrics
This commit is contained in:
parent
cef2c6054d
commit
1b44d873e7
6 changed files with 39 additions and 15 deletions
|
@ -267,6 +267,10 @@ type EmbedRequest struct {
|
||||||
type EmbedResponse struct {
|
type EmbedResponse struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Embeddings [][]float32 `json:"embeddings"`
|
Embeddings [][]float32 `json:"embeddings"`
|
||||||
|
|
||||||
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||||
|
|
|
@ -69,6 +69,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||||
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
|
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
|
||||||
t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
|
t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.PromptEvalCount != 8 {
|
||||||
|
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||||
|
@ -97,6 +101,10 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||||
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
|
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
|
||||||
t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
|
t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.PromptEvalCount != 16 {
|
||||||
|
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||||
|
|
7
llm/ext_server/server.cpp
vendored
7
llm/ext_server/server.cpp
vendored
|
@ -1221,6 +1221,7 @@ struct llama_server_context
|
||||||
res.result_json = json
|
res.result_json = json
|
||||||
{
|
{
|
||||||
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
||||||
|
{"timings", slot.get_formated_timings()},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3203,11 +3204,15 @@ int main(int argc, char **argv) {
|
||||||
|
|
||||||
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
||||||
json embeddings = json::array();
|
json embeddings = json::array();
|
||||||
|
|
||||||
|
int prompt_n = 0;
|
||||||
for (auto & elem : responses) {
|
for (auto & elem : responses) {
|
||||||
embeddings.push_back(elem.at("embedding"));
|
embeddings.push_back(elem.at("embedding"));
|
||||||
|
prompt_n += elem.at("timings").at("prompt_n").get<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the result
|
// send the result
|
||||||
json embedding_res = json{{"embedding", embeddings}};
|
json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
|
||||||
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
|
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -33,7 +33,7 @@ type LlamaServer interface {
|
||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embed(ctx context.Context, input []string) ([][]float32, error)
|
Embed(ctx context.Context, input []string) (*EmbedResponse, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
|
@ -879,10 +879,11 @@ type EmbedRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbedResponse struct {
|
type EmbedResponse struct {
|
||||||
Embedding [][]float32 `json:"embedding"`
|
Embedding [][]float32 `json:"embedding"`
|
||||||
|
PromptEvalCount int `json:"prompt_n"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -924,12 +925,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
|
||||||
return nil, fmt.Errorf("%s", body)
|
return nil, fmt.Errorf("%s", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var embedding EmbedResponse
|
var e EmbedResponse
|
||||||
if err := json.Unmarshal(body, &embedding); err != nil {
|
if err := json.Unmarshal(body, &e); err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return embedding.Embedding, nil
|
return &e, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenizeRequest struct {
|
type TokenizeRequest struct {
|
||||||
|
|
|
@ -284,6 +284,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) EmbedHandler(c *gin.Context) {
|
func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
|
checkpointStart := time.Now()
|
||||||
var req api.EmbedRequest
|
var req api.EmbedRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -332,6 +333,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
kvData, err := getKVData(m.ModelPath, false)
|
kvData, err := getKVData(m.ModelPath, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
@ -370,13 +373,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, e := range embeddings {
|
for i, e := range embeddings.Embedding {
|
||||||
embeddings[i] = normalize(e)
|
embeddings.Embedding[i] = normalize(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := api.EmbedResponse{
|
resp := api.EmbedResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Embeddings: embeddings,
|
Embeddings: embeddings.Embedding,
|
||||||
|
TotalDuration: time.Since(checkpointStart),
|
||||||
|
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||||
|
PromptEvalCount: embeddings.PromptEvalCount,
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
@ -428,9 +434,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
embedding := make([]float64, len(embeddings[0]))
|
embedding := make([]float64, len(embeddings.Embedding[0]))
|
||||||
|
|
||||||
for i, v := range embeddings[0] {
|
for i, v := range embeddings.Embedding[0] {
|
||||||
embedding[i] = float64(v)
|
embedding[i] = float64(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -709,7 +709,7 @@ type mockLlm struct {
|
||||||
pingResp error
|
pingResp error
|
||||||
waitResp error
|
waitResp error
|
||||||
completionResp error
|
completionResp error
|
||||||
embedResp [][]float32
|
embedResp *llm.EmbedResponse
|
||||||
embedRespErr error
|
embedRespErr error
|
||||||
tokenizeResp []int
|
tokenizeResp []int
|
||||||
tokenizeRespErr error
|
tokenizeRespErr error
|
||||||
|
@ -727,7 +727,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
|
||||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
return s.completionResp
|
return s.completionResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
|
||||||
return s.embedResp, s.embedRespErr
|
return s.embedResp, s.embedRespErr
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
|
|
Loading…
Reference in a new issue