Merge pull request #306 from jmorganca/default-keep-system

automatically set num_keep if num_keep < 0
This commit is contained in:
Michael Yang 2023-08-08 09:25:34 -07:00 committed by GitHub
commit f2074ed4c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 14 deletions

View file

@@ -266,6 +266,7 @@ func DefaultOptions() Options {
UseNUMA: false, UseNUMA: false,
NumCtx: 2048, NumCtx: 2048,
NumKeep: -1,
NumBatch: 512, NumBatch: 512,
NumGPU: 1, NumGPU: 1,
NumGQA: 1, NumGQA: 1,

View file

@@ -189,10 +189,6 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
tokens[i] = C.llama_token(ctx[i]) tokens[i] = C.llama_token(ctx[i])
} }
if len(tokens) == 0 {
tokens = llm.tokenize(" ")
}
llm.marshalPrompt(tokens, prompt) llm.marshalPrompt(tokens, prompt)
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed)) C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
@@ -208,7 +204,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return err return err
} }
b.WriteString(llm.detokenize(token)) b.WriteString(llm.Decode(token))
if err := llm.checkStopConditions(b); err != nil { if err := llm.checkStopConditions(b); err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@@ -226,17 +222,15 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
} }
} }
last := make([]int, 0, len(llm.last)) embd := make([]int, len(llm.embd))
for _, i := range llm.last { for i := range llm.embd {
if i != 0 { embd[i] = int(llm.embd[i])
last = append(last, int(i))
}
} }
timings := C.llama_get_timings(llm.ctx) timings := C.llama_get_timings(llm.ctx)
fn(api.GenerateResponse{ fn(api.GenerateResponse{
Done: true, Done: true,
Context: last, Context: embd,
SampleCount: int(timings.n_sample), SampleCount: int(timings.n_sample),
SampleDuration: parseDurationMs(float64(timings.t_sample_ms)), SampleDuration: parseDurationMs(float64(timings.t_sample_ms)),
PromptEvalCount: int(timings.n_p_eval), PromptEvalCount: int(timings.n_p_eval),
@@ -261,7 +255,7 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
} }
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token { func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
tokens := append(ctx, llm.tokenize(prompt)...) tokens := append(ctx, llm.Encode(prompt)...)
if llm.NumKeep < 0 { if llm.NumKeep < 0 {
llm.NumKeep = len(tokens) llm.NumKeep = len(tokens)
} }
@@ -303,7 +297,7 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
return tokens return tokens
} }
func (llm *LLM) tokenize(prompt string) []C.llama_token { func (llm *LLM) Encode(prompt string) []C.llama_token {
cPrompt := C.CString(prompt) cPrompt := C.CString(prompt)
defer C.free(unsafe.Pointer(cPrompt)) defer C.free(unsafe.Pointer(cPrompt))
@@ -315,7 +309,7 @@ func (llm *LLM) tokenize(prompt string) []C.llama_token {
return nil return nil
} }
func (llm *LLM) detokenize(tokens ...C.llama_token) string { func (llm *LLM) Decode(tokens ...C.llama_token) string {
var sb strings.Builder var sb strings.Builder
for _, token := range tokens { for _, token := range tokens {
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))

View file

@@ -78,6 +78,25 @@ func GenerateHandler(c *gin.Context) {
return return
} }
if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
tokensWithSystem := llm.Encode(promptWithSystem)
tokensNoSystem := llm.Encode(promptNoSystem)
llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
}
loaded.llm = llm loaded.llm = llm
loaded.digest = model.Digest loaded.digest = model.Digest
loaded.options = opts loaded.options = opts