Update low level api example

This commit is contained in:
Andrei Betlen 2023-04-01 13:02:10 -04:00
parent 5f2e822b59
commit 5e011145c5

View file

@@ -35,7 +35,7 @@ remaining_tokens = n_predict
embd = []
last_n_size = 64
last_n_tokens = [0] * last_n_size
last_n_tokens_data = [0] * last_n_size
n_batch = 24
while remaining_tokens > 0:
@@ -49,21 +49,21 @@ while remaining_tokens > 0:
if len(embd_inp) <= input_consumed:
id = llama_cpp.llama_sample_top_p_top_k(
ctx,
(llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens),
len(last_n_tokens),
(llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
len(last_n_tokens_data),
40,
0.8,
0.2,
1.0 / 0.85,
)
last_n_tokens = last_n_tokens[1:] + [id]
last_n_tokens_data = last_n_tokens_data[1:] + [id]
embd.append(id)
input_noecho = False
remaining_tokens -= 1
else:
while len(embd_inp) > input_consumed:
embd.append(embd_inp[input_consumed])
last_n_tokens = last_n_tokens[1:] + [embd_inp[input_consumed]]
last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
input_consumed += 1
if len(embd) >= n_batch:
break