2023-09-07 21:50:47 +00:00
|
|
|
import ctypes
|
|
|
|
import os
|
2023-03-24 22:57:25 +00:00
|
|
|
import multiprocessing
|
|
|
|
|
|
|
|
import llama_cpp
|
|
|
|
|
|
|
|
# Use every available CPU core for llama_eval.
N_THREADS = multiprocessing.cpu_count()

# Path to the model file; override it with the MODEL environment variable.
MODEL_PATH = os.environ.get('MODEL', "../models/7B/ggml-model.bin")

# Instruction/response style prompt. Kept as bytes because it is passed
# directly to llama_tokenize.
prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"

# The same default context parameters are used both to load the model and to
# create the inference context from it.
lparams = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode('utf-8'), lparams)
if not model:
    # The C API returns NULL (None through ctypes) on failure. Stop here rather
    # than handing a NULL model to llama_new_context_with_model.
    raise RuntimeError(f"failed to load model from {MODEL_PATH!r}")

ctx = llama_cpp.llama_new_context_with_model(model, lparams)
if not ctx:
    # Release the already-loaded model before bailing out.
    llama_cpp.llama_free_model(model)
    raise RuntimeError("failed to create a llama context for the model")
|
2023-03-24 22:57:25 +00:00
|
|
|
|
|
|
|
# Warm-up pass: evaluate a few dummy token ids once before the prompt (the
# historical intent was to determine the inference memory needed per token).
# They are evaluated at n_past = 0; the prompt below is also evaluated
# starting from n_past = 0.
tmp = [0, 1, 2, 3]
if llama_cpp.llama_eval(
    ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, N_THREADS
) != 0:
    raise RuntimeError("llama_eval failed during the warm-up pass")

# Number of tokens already evaluated into the context.
n_past = 0

# Prefix the prompt with a space before tokenizing.
# NOTE(review): newer llama.cpp tokenizers prepend a space for SPM vocabularies
# themselves; confirm this does not produce a double leading space.
prompt = b" " + prompt

# Tokenize the prompt, adding BOS. The first buffer holds len(prompt) + 1
# tokens. If that is too small, llama_tokenize returns the negated number of
# tokens it needs, so retry once with a buffer of exactly that size.
embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))()
n_of_tok = llama_cpp.llama_tokenize(ctx, prompt, embd_inp, len(embd_inp), True)
if n_of_tok < 0:
    embd_inp = (llama_cpp.llama_token * -n_of_tok)()
    n_of_tok = llama_cpp.llama_tokenize(ctx, prompt, embd_inp, len(embd_inp), True)
if n_of_tok < 0:
    raise RuntimeError("llama_tokenize failed to tokenize the prompt")

# Keep only the tokens actually produced, as a plain Python list.
embd_inp = embd_inp[:n_of_tok]
|
|
|
|
|
|
|
|
# Total number of tokens the context window can hold.
n_ctx = llama_cpp.llama_n_ctx(ctx)

# Ask for 20 new tokens, clamped so prompt + completion still fit in n_ctx.
# remaining_tokens counts down as tokens are sampled.
n_predict = min(20, n_ctx - len(embd_inp))
remaining_tokens = n_predict

# Prompt-feeding state: index of the next prompt token to queue, and a flag
# that suppresses echoing of the current batch when set.
input_consumed, input_noecho = 0, False

# Tokens waiting for the next llama_eval call, and the largest number of
# prompt tokens queued per call.
embd, n_batch = [], 24

# Rolling history of recent token ids for the penalty samplers, zero-filled.
last_n_size = 64
last_n_tokens_data = last_n_size * [0]

# Penalty settings: how much history to consider, and the penalty values.
last_n_repeat = 64
repeat_penalty = 1
frequency_penalty, presence_penalty = 0.0, 0.0
|
2023-03-24 22:57:25 +00:00
|
|
|
|
|
|
|
import codecs  # only needed for the streaming UTF-8 decoder used below

# Pieces from llama_token_to_piece_with_model are raw bytes, and one multi-byte
# UTF-8 character can be spread over several tokens. Decoding each piece on
# its own raises UnicodeDecodeError in that case. Instead, feed all output
# bytes through one incremental decoder that carries partial characters
# across calls.
utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

# Loop-invariant: the end-of-sequence token id.
eos_token = llama_cpp.llama_token_eos(ctx)


def _token_to_bytes(token):
    """Return the raw bytes of ``token``'s text piece from the model vocabulary.

    llama_token_to_piece_with_model returns the piece length, or the negated
    required length when the buffer is too small. In the latter case, retry
    once with a buffer of exactly that size.

    Raises:
        RuntimeError: if the call still does not report a valid length.
    """
    size = 32
    buffer = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece_with_model(
        model, llama_cpp.llama_token(token), buffer, size)
    if n < 0:
        size = -n
        buffer = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece_with_model(
            model, llama_cpp.llama_token(token), buffer, size)
    if not 0 <= n <= size:
        raise RuntimeError(
            f"llama_token_to_piece_with_model failed for token {token} (returned {n})")
    return buffer[:n]


def _sample_next_token(recent_tokens):
    """Sample the next token id from the logits of the last llama_eval call.

    ``recent_tokens`` is the rolling token history. Its last ``last_n_repeat``
    entries feed the repetition and frequency/presence penalties. Then top-k,
    top-p and temperature are applied before drawing a token.
    """
    logits = llama_cpp.llama_get_logits(ctx)
    n_vocab = llama_cpp.llama_n_vocab(ctx)

    # One candidate per vocabulary entry, scored by its raw logit.
    candidates_data = (llama_cpp.llama_token_data * n_vocab)(*[
        llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
        for token_id in range(n_vocab)
    ])
    candidates_p = ctypes.pointer(
        llama_cpp.llama_token_data_array(candidates_data, len(candidates_data), False))

    # Build the penalty array from the most recent history and pass its real
    # length, so the size handed to C always matches the array.
    window = recent_tokens[-last_n_repeat:] if last_n_repeat > 0 else []
    penalty_tokens = (llama_cpp.llama_token * len(window))(*window)
    llama_cpp.llama_sample_repetition_penalty(
        ctx, candidates_p, penalty_tokens, len(penalty_tokens), repeat_penalty)
    llama_cpp.llama_sample_frequency_and_presence_penalties(
        ctx, candidates_p, penalty_tokens, len(penalty_tokens),
        frequency_penalty, presence_penalty)

    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=0.2)
    return llama_cpp.llama_sample_token(ctx, candidates_p)


while remaining_tokens > 0:
    # Evaluate whatever was queued in the previous iteration.
    if embd:
        tokens = (llama_cpp.llama_token * len(embd))(*embd)
        if llama_cpp.llama_eval(ctx, tokens, len(embd), n_past, N_THREADS) != 0:
            raise RuntimeError("llama_eval failed")
    n_past += len(embd)
    embd = []

    if input_consumed >= len(embd_inp):
        # Prompt fully consumed: sample exactly one new token.
        token = _sample_next_token(last_n_tokens_data)
        last_n_tokens_data = last_n_tokens_data[1:] + [token]
        embd.append(token)
        input_noecho = False
        remaining_tokens -= 1
    else:
        # Still feeding the prompt: queue up to n_batch prompt tokens,
        # recording each one in the penalty history as well.
        while input_consumed < len(embd_inp):
            embd.append(embd_inp[input_consumed])
            last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
            input_consumed += 1
            if len(embd) >= n_batch:
                break

    # Echo the queued tokens (prompt chunk or the newly sampled token).
    if not input_noecho:
        for token in embd:
            print(utf8_decoder.decode(_token_to_bytes(token)), end="", flush=True)

    if embd and embd[-1] == eos_token:
        break

# Flush any incomplete trailing character still buffered in the decoder.
print(utf8_decoder.decode(b"", final=True), end="", flush=True)
|
|
|
|
|
|
|
|
# Terminate the streamed output line.
print()

# Report llama.cpp's timing statistics for this context.
llama_cpp.llama_print_timings(ctx)

# Free the context first, then the model it was created from
# (ctx was built from model via llama_new_context_with_model).
llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
|