token repeat limit for prediction requests (#3080)
This commit is contained in:
parent
a54d4a28dc
commit
3e22611200
1 changed file with 27 additions and 7 deletions
|
@ -228,17 +228,14 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
||||||
}
|
}
|
||||||
|
|
||||||
retryNeeded := false
|
retryNeeded := false
|
||||||
|
// keep track of the last token generated, this is used to abort if the model starts looping
|
||||||
|
var lastToken string
|
||||||
|
var tokenRepeat int
|
||||||
out:
|
out:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// This handles the request cancellation
|
return cancelCompletion(llm, resp)
|
||||||
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
|
|
||||||
if resp.id < 0 {
|
|
||||||
return extServerResponseToErr(resp)
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
var result C.ext_server_task_result_t
|
var result C.ext_server_task_result_t
|
||||||
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
|
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
|
||||||
|
@ -261,6 +258,20 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
||||||
break out
|
break out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.TrimSpace(p.Content) == lastToken:
|
||||||
|
tokenRepeat++
|
||||||
|
default:
|
||||||
|
lastToken = strings.TrimSpace(p.Content)
|
||||||
|
tokenRepeat = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 30 picked as an arbitrary max token repeat limit, modify as needed
|
||||||
|
if tokenRepeat > 30 {
|
||||||
|
slog.Debug("prediction aborted, token repeat limit reached")
|
||||||
|
return cancelCompletion(llm, resp)
|
||||||
|
}
|
||||||
|
|
||||||
if p.Content != "" {
|
if p.Content != "" {
|
||||||
fn(PredictResult{
|
fn(PredictResult{
|
||||||
Content: p.Content,
|
Content: p.Content,
|
||||||
|
@ -288,6 +299,15 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
||||||
return fmt.Errorf("max retries exceeded")
|
return fmt.Errorf("max retries exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
|
||||||
|
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
|
||||||
|
if resp.id < 0 {
|
||||||
|
return extServerResponseToErr(resp)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue