diff --git a/api/types.go b/api/types.go
index dee53d50..e9dc1889 100644
--- a/api/types.go
+++ b/api/types.go
@@ -264,6 +264,7 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
 		NumCtx:   2048,
+		NumKeep:  -1,
 		NumBatch: 512,
 		NumGPU:   1,
 		NumGQA:   1,
diff --git a/llama/llama.go b/llama/llama.go
index c18dc952..88421e53 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -189,10 +189,6 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		tokens[i] = C.llama_token(ctx[i])
 	}
 
-	if len(tokens) == 0 {
-		tokens = llm.tokenize(" ")
-	}
-
 	llm.marshalPrompt(tokens, prompt)
 
 	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
@@ -208,7 +204,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			return err
 		}
 
-		b.WriteString(llm.detokenize(token))
+		b.WriteString(llm.Decode(token))
 
 		if err := llm.checkStopConditions(b); err != nil {
 			if errors.Is(err, io.EOF) {
@@ -226,17 +222,15 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}
 	}
 
-	last := make([]int, 0, len(llm.last))
-	for _, i := range llm.last {
-		if i != 0 {
-			last = append(last, int(i))
-		}
+	embd := make([]int, len(llm.embd))
+	for i := range llm.embd {
+		embd[i] = int(llm.embd[i])
 	}
 
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
-		Context:            last,
+		Context:            embd,
 		SampleCount:        int(timings.n_sample),
 		SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
 		PromptEvalCount:    int(timings.n_p_eval),
@@ -261,7 +255,7 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
 }
 
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
-	tokens := append(ctx, llm.tokenize(prompt)...)
+	tokens := append(ctx, llm.Encode(prompt)...)
 	if llm.NumKeep < 0 {
 		llm.NumKeep = len(tokens)
 	}
@@ -303,7 +297,7 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	return tokens
 }
 
-func (llm *LLM) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) Encode(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
@@ -315,7 +309,7 @@ func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *LLM) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) Decode(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
diff --git a/server/routes.go b/server/routes.go
index 5e8a356f..0023072c 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -78,6 +78,25 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	if opts.NumKeep < 0 {
+		promptWithSystem, err := model.Prompt(api.GenerateRequest{})
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		tokensWithSystem := llm.Encode(promptWithSystem)
+		tokensNoSystem := llm.Encode(promptNoSystem)
+
+		llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+	}
+
 	loaded.llm = llm
 	loaded.digest = model.Digest
 	loaded.options = opts
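
Note on the server/routes.go hunk: when a request leaves NumKeep unset (the new default of -1), the handler renders the prompt template twice, once as a fresh request (system prompt included) and once with a non-empty Context (system prompt omitted), and takes the difference in token counts, plus one, as the number of leading tokens that belong to the system prompt. The intent of NumKeep, in the spirit of llama.cpp's n_keep, is that those leading tokens survive when the accumulated prompt has to be truncated to fit NumCtx. The following is only an illustrative sketch of that idea with a hypothetical truncateKeep helper, not the actual marshalPrompt code from this patch:

    package main

    import "fmt"

    // truncateKeep illustrates the NumKeep idea: when the token list would
    // overflow the context window, keep the first numKeep tokens (typically the
    // rendered system prompt) plus as much of the recent tail as still fits.
    func truncateKeep(tokens []int, numCtx, numKeep int) []int {
    	if len(tokens) <= numCtx {
    		return tokens
    	}
    	out := append([]int{}, tokens[:numKeep]...)
    	return append(out, tokens[len(tokens)-(numCtx-numKeep):]...)
    }

    func main() {
    	tokens := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    	fmt.Println(truncateKeep(tokens, 6, 2)) // [1 2 7 8 9 10]
    }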