Merge pull request #225 from jmorganca/stop-conditions
add stop conditions
commit 8fa477fadb
2 changed files with 39 additions and 13 deletions
api/types.go (27 changes)
@@ -165,19 +165,20 @@ type Options struct {
 	EmbeddingOnly bool `json:"embedding_only,omitempty"`

 	// Predict options
 	RepeatLastN      int      `json:"repeat_last_n,omitempty"`
 	RepeatPenalty    float32  `json:"repeat_penalty,omitempty"`
 	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
 	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
 	Temperature      float32  `json:"temperature,omitempty"`
 	TopK             int      `json:"top_k,omitempty"`
 	TopP             float32  `json:"top_p,omitempty"`
 	TFSZ             float32  `json:"tfs_z,omitempty"`
 	TypicalP         float32  `json:"typical_p,omitempty"`
 	Mirostat         int      `json:"mirostat,omitempty"`
 	MirostatTau      float32  `json:"mirostat_tau,omitempty"`
 	MirostatEta      float32  `json:"mirostat_eta,omitempty"`
 	PenalizeNewline  bool     `json:"penalize_newline,omitempty"`
+	StopConditions   []string `json:"stop_conditions,omitempty"`

 	NumThread int `json:"num_thread,omitempty"`
 }
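For a sense of how the new field serializes, here is a minimal sketch; the Options struct is trimmed to two of the fields shown in the hunk above (the full struct lives in api/types.go), and the stop strings are hypothetical example values, not from the commit:

package main

import (
	"encoding/json"
	"fmt"
)

// Options is reduced to the fields relevant to this change.
type Options struct {
	Temperature    float32  `json:"temperature,omitempty"`
	StopConditions []string `json:"stop_conditions,omitempty"`
}

func main() {
	// Generation stops once the pending output buffer exactly matches
	// one of these strings (example values only).
	opts := Options{
		Temperature:    0.8,
		StopConditions: []string{"\n\n", "User:"},
	}

	out, err := json.Marshal(opts)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
	// {"temperature":0.8,"stop_conditions":["\n\n","User:"]}
}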
@@ -172,6 +172,8 @@ func (llm *LLM) Close() {
 	C.llama_print_timings(llm.ctx)
 }

+var errNeedMoreData = errors.New("need more data")
+
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	C.llama_reset_timings(llm.ctx)

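The two sentinel outcomes drive the predict loop: io.EOF means a stop condition matched exactly, while the new errNeedMoreData means the buffer is still a prefix of some stop condition, so the loop should keep reading tokens. A minimal sketch of that dispatch; the classify helper is hypothetical, standing in for the branch added to Predict below:

package main

import (
	"errors"
	"fmt"
	"io"
)

// errNeedMoreData mirrors the sentinel introduced in this commit.
var errNeedMoreData = errors.New("need more data")

// classify is a hypothetical stand-in for the error handling in Predict.
func classify(err error) string {
	switch {
	case errors.Is(err, io.EOF):
		return "break: stop condition hit"
	case errors.Is(err, errNeedMoreData):
		return "continue: keep buffering"
	case err != nil:
		return "return err"
	default:
		return "flush buffer to callback"
	}
}

func main() {
	fmt.Println(classify(io.EOF))          // break: stop condition hit
	fmt.Println(classify(errNeedMoreData)) // continue: keep buffering
	fmt.Println(classify(nil))             // flush buffer to callback
}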
@@ -200,6 +202,17 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}

 		b.WriteString(llm.detokenize(token))

+		if err := llm.checkStopConditions(b); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			} else if errors.Is(err, errNeedMoreData) {
+				continue
+			}
+
+			return err
+		}
+
 		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
 			fn(api.GenerateResponse{Response: b.String()})
 			b.Reset()
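The flush guard in the context lines above, utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax, exists because a detokenizer can split a multi-byte rune across two tokens. A small self-contained demonstration of why the guard waits before flushing:

package main

import (
	"fmt"
	"unicode/utf8"
)

func main() {
	// "é" encodes to two bytes in UTF-8. If a token boundary falls
	// between them, the buffer is briefly invalid UTF-8.
	full := []byte("é")
	half := full[:1]

	// Invalid and shorter than utf8.UTFMax (4 bytes): keep buffering.
	fmt.Println(utf8.Valid(half), len(half) >= utf8.UTFMax) // false false

	// Valid again once the second byte arrives: safe to flush.
	fmt.Println(utf8.Valid(full), len(full) >= utf8.UTFMax) // true false
}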
@@ -228,6 +241,18 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	return nil
 }

+func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+	for _, stopCondition := range llm.StopConditions {
+		if stopCondition == b.String() {
+			return io.EOF
+		} else if strings.HasPrefix(stopCondition, b.String()) {
+			return errNeedMoreData
+		}
+	}
+
+	return nil
+}
+
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
 	tokens := append(ctx, llm.tokenize(prompt)...)
 	if llm.NumKeep < 0 {
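Taken together, checkStopConditions yields three outcomes per token: exact match (stop), prefix match (wait), or no match (flush). A self-contained sketch tracing that rule over hypothetical token pieces; note the check helper here takes a *bytes.Buffer to avoid copying, a deliberate departure from the by-value signature in the commit:

package main

import (
	"bytes"
	"fmt"
	"strings"
)

// check applies the same rule as checkStopConditions above.
func check(b *bytes.Buffer, stops []string) string {
	for _, stop := range stops {
		if stop == b.String() {
			return "stop"
		} else if strings.HasPrefix(stop, b.String()) {
			return "wait"
		}
	}
	return "flush"
}

func main() {
	stops := []string{"User:"}
	var b bytes.Buffer

	// Hypothetical detokenized pieces; real token boundaries depend on
	// the model's vocabulary.
	for _, piece := range []string{"Us", "er", ":"} {
		b.WriteString(piece)
		fmt.Printf("buffer=%q -> %s\n", b.String(), check(&b, stops))
	}
	// buffer="Us"    -> wait
	// buffer="User"  -> wait
	// buffer="User:" -> stop
}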