From 3e22611200e9fdc44d24976cce7b36197448c972 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 12 Mar 2024 22:08:25 -0400 Subject: [PATCH] token repeat limit for prediction requests (#3080) --- llm/dyn_ext_server.go | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go index fd60f1a4..832d3c47 100644 --- a/llm/dyn_ext_server.go +++ b/llm/dyn_ext_server.go @@ -228,17 +228,14 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu } retryNeeded := false + // keep track of the last token generated, this is used to abort if the model starts looping + var lastToken string + var tokenRepeat int out: for { select { case <-ctx.Done(): - // This handles the request cancellation - C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp) - if resp.id < 0 { - return extServerResponseToErr(resp) - } else { - return nil - } + return cancelCompletion(llm, resp) default: var result C.ext_server_task_result_t C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result) @@ -261,6 +258,20 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu break out } + switch { + case strings.TrimSpace(p.Content) == lastToken: + tokenRepeat++ + default: + lastToken = strings.TrimSpace(p.Content) + tokenRepeat = 0 + } + + // 30 picked as an arbitrary max token repeat limit, modify as needed + if tokenRepeat > 30 { + slog.Debug("prediction aborted, token repeat limit reached") + return cancelCompletion(llm, resp) + } + if p.Content != "" { fn(PredictResult{ Content: p.Content, @@ -288,6 +299,15 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu return fmt.Errorf("max retries exceeded") } +func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error { + C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp) + if resp.id < 0 { + return extServerResponseToErr(resp) + } else { + return nil + } +} + func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) { data, err := json.Marshal(TokenizeRequest{Content: prompt}) if err != nil {