Merge pull request #306 from jmorganca/default-keep-system
automatically set num_keep if num_keep < 0
commit f2074ed4c0
3 changed files with 28 additions and 14 deletions
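
The option this PR touches maps onto llama.cpp's n_keep: the number of tokens at the start of the prompt that are preserved when the context window fills up and older tokens are discarded. With num_keep defaulting to -1, the server now derives the value automatically by rendering the prompt template with and without the system prompt, encoding both, and keeping the difference, so the system prompt survives context truncation. Below is a minimal, self-contained sketch of that derivation; the encode helper is a hypothetical stand-in for the model's real tokenizer, not part of this codebase.

package main

import (
	"fmt"
	"strings"
)

// encode is a hypothetical whitespace tokenizer used only to make this sketch
// runnable; the server uses llm.Encode from llama.go instead.
func encode(prompt string) []string {
	return strings.Fields(prompt)
}

// deriveNumKeep mirrors the logic added to GenerateHandler: when num_keep is
// left at -1, tokenize the rendered prompt with and without the system prompt
// and keep the difference, plus one for the leading token.
func deriveNumKeep(numKeep int, promptWithSystem, promptNoSystem string) int {
	if numKeep >= 0 {
		return numKeep // explicitly set by the caller, leave it alone
	}
	return len(encode(promptWithSystem)) - len(encode(promptNoSystem)) + 1
}

func main() {
	withSystem := "You are a helpful assistant. Hello"
	noSystem := "Hello"
	fmt.Println(deriveNumKeep(-1, withSystem, noSystem)) // prints 6: the system-prompt tokens plus one
}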

@@ -266,6 +266,7 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
 		NumCtx:   2048,
+		NumKeep:  -1,
 		NumBatch: 512,
 		NumGPU:   1,
 		NumGQA:   1,

@@ -189,10 +189,6 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		tokens[i] = C.llama_token(ctx[i])
 	}
 
-	if len(tokens) == 0 {
-		tokens = llm.tokenize(" ")
-	}
-
 	llm.marshalPrompt(tokens, prompt)
 
 	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))

@@ -208,7 +204,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			return err
 		}
 
-		b.WriteString(llm.detokenize(token))
+		b.WriteString(llm.Decode(token))
 
 		if err := llm.checkStopConditions(b); err != nil {
 			if errors.Is(err, io.EOF) {

@@ -226,17 +222,15 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 		}
 	}
 
-	last := make([]int, 0, len(llm.last))
-	for _, i := range llm.last {
-		if i != 0 {
-			last = append(last, int(i))
-		}
+	embd := make([]int, len(llm.embd))
+	for i := range llm.embd {
+		embd[i] = int(llm.embd[i])
 	}
 
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:            true,
-		Context:         last,
+		Context:         embd,
 		SampleCount:     int(timings.n_sample),
 		SampleDuration:  parseDurationMs(float64(timings.t_sample_ms)),
 		PromptEvalCount: int(timings.n_p_eval),

@@ -261,7 +255,7 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
 }
 
 func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
-	tokens := append(ctx, llm.tokenize(prompt)...)
+	tokens := append(ctx, llm.Encode(prompt)...)
 	if llm.NumKeep < 0 {
 		llm.NumKeep = len(tokens)
 	}

@@ -303,7 +297,7 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	return tokens
 }
 
-func (llm *LLM) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) Encode(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 

@@ -315,7 +309,7 @@ func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *LLM) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) Decode(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))

@@ -78,6 +78,25 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	if opts.NumKeep < 0 {
+		promptWithSystem, err := model.Prompt(api.GenerateRequest{})
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		tokensWithSystem := llm.Encode(promptWithSystem)
+		tokensNoSystem := llm.Encode(promptNoSystem)
+
+		llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+	}
+
 	loaded.llm = llm
 	loaded.digest = model.Digest
 	loaded.options = opts
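
Taken together with the marshalPrompt hunk above, the resolution order implied by this diff is: an explicit non-negative num_keep from the request is used as-is; otherwise the server derives one from the system prompt; and if the value is still negative when the prompt is marshalled, the whole initial prompt is kept. A small sketch of that precedence, using a hypothetical helper rather than the project's API:

package main

import "fmt"

// resolveNumKeep is a hypothetical summary of the order the hunks above imply:
// an explicit non-negative value wins, then the value derived from the system
// prompt, and finally the full length of the initial prompt.
func resolveNumKeep(requested, derivedFromSystem, promptTokens int) int {
	if requested >= 0 {
		return requested
	}
	if derivedFromSystem >= 0 {
		return derivedFromSystem
	}
	return promptTokens
}

func main() {
	fmt.Println(resolveNumKeep(-1, 6, 200))  // 6: keep the system-prompt tokens
	fmt.Println(resolveNumKeep(-1, -1, 200)) // 200: keep the whole initial prompt
	fmt.Println(resolveNumKeep(32, 6, 200))  // 32: the caller's explicit value
}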