Update low_level_api_llama_cpp.py to match current API (#1023)
This commit is contained in:
parent
095c650006
commit
cfd698c75c
1 changed files with 7 additions and 7 deletions
|
@ -73,7 +73,7 @@ while remaining_tokens > 0:
|
||||||
embd = []
|
embd = []
|
||||||
if len(embd_inp) <= input_consumed:
|
if len(embd_inp) <= input_consumed:
|
||||||
logits = llama_cpp.llama_get_logits(ctx)
|
logits = llama_cpp.llama_get_logits(ctx)
|
||||||
n_vocab = llama_cpp.llama_n_vocab(ctx)
|
n_vocab = llama_cpp.llama_n_vocab(model)
|
||||||
|
|
||||||
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
|
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
|
||||||
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
|
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
|
||||||
|
@ -83,12 +83,12 @@ while remaining_tokens > 0:
|
||||||
llama_cpp.llama_token_data_array(_arr, len(_arr), False))
|
llama_cpp.llama_token_data_array(_arr, len(_arr), False))
|
||||||
|
|
||||||
_arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
|
_arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
|
||||||
llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
|
llama_cpp.llama_sample_repetition_penalties(ctx, candidates_p,
|
||||||
_arr,
|
_arr,
|
||||||
last_n_repeat, repeat_penalty)
|
penalty_last_n=last_n_repeat,
|
||||||
llama_cpp.llama_sample_frequency_and_presence_penalties(ctx, candidates_p,
|
penalty_repeat=repeat_penalty,
|
||||||
_arr,
|
penalty_freq=frequency_penalty,
|
||||||
last_n_repeat, frequency_penalty, presence_penalty)
|
penalty_present=presence_penalty)
|
||||||
|
|
||||||
llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
|
llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
|
||||||
llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
|
llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
|
||||||
|
|
Loading…
Reference in a new issue