From fadf75f99d7a2c3cc1be7bfc60c8bb8eb3da56f6 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 27 Jul 2023 11:27:49 -0700
Subject: [PATCH] add stop conditions

---
 api/types.go   | 27 ++++++++++++++-------------
 llama/llama.go | 25 +++++++++++++++++++++++++
 2 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/api/types.go b/api/types.go
index 9e5991dc..e8f91270 100644
--- a/api/types.go
+++ b/api/types.go
@@ -165,19 +165,20 @@ type Options struct {
 	EmbeddingOnly bool `json:"embedding_only,omitempty"`
 
 	// Predict options
-	RepeatLastN      int     `json:"repeat_last_n,omitempty"`
-	RepeatPenalty    float32 `json:"repeat_penalty,omitempty"`
-	FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float32 `json:"presence_penalty,omitempty"`
-	Temperature      float32 `json:"temperature,omitempty"`
-	TopK             int     `json:"top_k,omitempty"`
-	TopP             float32 `json:"top_p,omitempty"`
-	TFSZ             float32 `json:"tfs_z,omitempty"`
-	TypicalP         float32 `json:"typical_p,omitempty"`
-	Mirostat         int     `json:"mirostat,omitempty"`
-	MirostatTau      float32 `json:"mirostat_tau,omitempty"`
-	MirostatEta      float32 `json:"mirostat_eta,omitempty"`
-	PenalizeNewline  bool    `json:"penalize_newline,omitempty"`
+	RepeatLastN      int      `json:"repeat_last_n,omitempty"`
+	RepeatPenalty    float32  `json:"repeat_penalty,omitempty"`
+	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
+	Temperature      float32  `json:"temperature,omitempty"`
+	TopK             int      `json:"top_k,omitempty"`
+	TopP             float32  `json:"top_p,omitempty"`
+	TFSZ             float32  `json:"tfs_z,omitempty"`
+	TypicalP         float32  `json:"typical_p,omitempty"`
+	Mirostat         int      `json:"mirostat,omitempty"`
+	MirostatTau      float32  `json:"mirostat_tau,omitempty"`
+	MirostatEta      float32  `json:"mirostat_eta,omitempty"`
+	PenalizeNewline  bool     `json:"penalize_newline,omitempty"`
+	StopConditions   []string `json:"stop_conditions,omitempty"`
 
 	NumThread int `json:"num_thread,omitempty"`
 }
diff --git a/llama/llama.go b/llama/llama.go
index c7bf194a..04e679a0 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -172,6 +172,8 @@ func (llm *LLM) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
+var errNeedMoreData = errors.New("need more data")
+
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	C.llama_reset_timings(llm.ctx)
 
@@ -200,6 +202,17 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}
 
 		b.WriteString(llm.detokenize(token))
+
+		if err := llm.checkStopConditions(b); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			} else if errors.Is(err, errNeedMoreData) {
+				continue
+			}
+
+			return err
+		}
+
 		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
 			fn(api.GenerateResponse{Response: b.String()})
 			b.Reset()
@@ -228,6 +241,18 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	return nil
 }
 
+func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+	for _, stopCondition := range llm.StopConditions {
+		if stopCondition == b.String() {
+			return io.EOF
+		} else if strings.HasPrefix(stopCondition, b.String()) {
+			return errNeedMoreData
+		}
+	}
+
+	return nil
+}
+
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
 	tokens := append(ctx, llm.tokenize(prompt)...)
 	if llm.NumKeep < 0 {
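
// A minimal standalone sketch of the stop-condition logic introduced by this
// patch, for illustration only. It lifts checkStopConditions off the *LLM
// receiver onto a plain []string parameter and takes *bytes.Buffer rather
// than the by-value bytes.Buffer used in the patch; everything outside the
// patch (main, stops, the token slice) is hypothetical.
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"strings"
)

// errNeedMoreData signals that the buffered text is a prefix of some stop
// condition: keep generating before deciding whether to stop.
var errNeedMoreData = errors.New("need more data")

// checkStopConditions returns io.EOF when the buffer exactly matches a stop
// condition (stop generating), errNeedMoreData when the buffer could still
// grow into one (hold back the flush), and nil when it is safe to emit.
func checkStopConditions(b *bytes.Buffer, stopConditions []string) error {
	for _, stopCondition := range stopConditions {
		if stopCondition == b.String() {
			return io.EOF
		} else if strings.HasPrefix(stopCondition, b.String()) {
			return errNeedMoreData
		}
	}

	return nil
}

func main() {
	stops := []string{"###"}

	// Feed the buffer one detokenized piece at a time, the way Predict does.
	// Because Predict runs this check before the UTF-8 flush, errNeedMoreData
	// (continue) keeps partial stop text away from the callback, and io.EOF
	// (break) drops the matched stop text entirely.
	var b bytes.Buffer
	for _, tok := range []string{"#", "#", "#"} {
		b.WriteString(tok)
		fmt.Printf("buffer=%q err=%v\n", b.String(), checkStopConditions(&b, stops))
	}
	// Prints:
	//   buffer="#" err=need more data
	//   buffer="##" err=need more data
	//   buffer="###" err=EOF
}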