Update low level api example

Andrei Betlen 2023-04-01 13:02:10 -04:00
parent 5f2e822b59
commit 5e011145c5


@@ -35,7 +35,7 @@ remaining_tokens = n_predict
 embd = []
 last_n_size = 64
-last_n_tokens = [0] * last_n_size
+last_n_tokens_data = [0] * last_n_size
 n_batch = 24
 while remaining_tokens > 0:
@@ -49,21 +49,21 @@ while remaining_tokens > 0:
     if len(embd_inp) <= input_consumed:
         id = llama_cpp.llama_sample_top_p_top_k(
             ctx,
-            (llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens),
-            len(last_n_tokens),
+            (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
+            len(last_n_tokens_data),
             40,
             0.8,
             0.2,
             1.0 / 0.85,
         )
-        last_n_tokens = last_n_tokens[1:] + [id]
+        last_n_tokens_data = last_n_tokens_data[1:] + [id]
         embd.append(id)
         input_noecho = False
         remaining_tokens -= 1
     else:
         while len(embd_inp) > input_consumed:
             embd.append(embd_inp[input_consumed])
-            last_n_tokens = last_n_tokens[1:] + [embd_inp[input_consumed]]
+            last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
             input_consumed += 1
             if len(embd) >= n_batch:
                 break
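For context, the pattern (llama_cpp.c_int * len(xs))(*xs) in the diff builds a C integer array from a Python list so the token window can be passed by pointer to the native sampling function. A minimal standalone sketch of that ctypes pattern, using plain ctypes and made-up token values (no model or llama_cpp import required):

import ctypes

# Hypothetical rolling window of recently seen token ids (illustrative values only).
last_n_tokens_data = [0] * 4
new_token = 42

# Shift the window and append the newly sampled token, as the example does.
last_n_tokens_data = last_n_tokens_data[1:] + [new_token]

# Build a C int array from the Python list; this mirrors the argument the
# low-level binding receives for its last_n_tokens_data parameter.
c_array = (ctypes.c_int * len(last_n_tokens_data))(*last_n_tokens_data)

print(list(c_array))  # -> [0, 0, 0, 42]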