token repeat limit for prediction requests (#3080)

This commit is contained in:
Bruce MacDonald 2024-03-12 22:08:25 -04:00 committed by GitHub
parent a54d4a28dc
commit 3e22611200
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -228,17 +228,14 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
} }
retryNeeded := false retryNeeded := false
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
out: out:
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// This handles the request cancellation return cancelCompletion(llm, resp)
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
} else {
return nil
}
default: default:
var result C.ext_server_task_result_t var result C.ext_server_task_result_t
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result) C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
@ -261,6 +258,20 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
break out break out
} }
switch {
case strings.TrimSpace(p.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(p.Content)
tokenRepeat = 0
}
// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog.Debug("prediction aborted, token repeat limit reached")
return cancelCompletion(llm, resp)
}
if p.Content != "" { if p.Content != "" {
fn(PredictResult{ fn(PredictResult{
Content: p.Content, Content: p.Content,
@ -288,6 +299,15 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
return fmt.Errorf("max retries exceeded") return fmt.Errorf("max retries exceeded")
} }
func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
} else {
return nil
}
}
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) { func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
data, err := json.Marshal(TokenizeRequest{Content: prompt}) data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil { if err != nil {