Merge pull request #680 from jbochi/low_level_api
Fix low level api examples
Commit b615fc3322
3 changed files with 41 additions and 18 deletions
@@ -187,7 +187,8 @@ Below is a short example demonstrating how to use the low-level API to tokenize
 >>> import ctypes
 >>> params = llama_cpp.llama_context_default_params()
 # use bytes for char * params
->>> ctx = llama_cpp.llama_init_from_file(b"./models/7b/ggml-model.bin", params)
+>>> model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
+>>> ctx = llama_cpp.llama_new_context_with_model(model, params)
 >>> max_tokens = params.n_ctx
 # use ctypes arrays for array params
 >>> tokens = (llama_cpp.llama_token * int(max_tokens))()
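Note: a minimal sketch of the two-step lifecycle the updated README snippet relies on, assuming the teardown order of the underlying llama.cpp C API (free the context before the model); the model path is a placeholder.

import llama_cpp

params = llama_cpp.llama_context_default_params()
# Placeholder path; substitute your own GGML model file.
model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

# ... tokenize / evaluate / sample against ctx here ...

# Assumed cleanup order: context first, then the model it was created from.
llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)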
@@ -24,6 +24,10 @@ class LLaMAInteract:
     def __init__(self, params: GptParams) -> None:
         # input args
         self.params = params
+        if self.params.path_session is None:
+            self.params.path_session = ""
+        if self.params.antiprompt is None:
+            self.params.antiprompt = ""

         if (self.params.perplexity):
             raise NotImplementedError("""************
@@ -66,7 +70,9 @@ specified) expect poor results""", file=sys.stderr)
         self.lparams.use_mlock = self.params.use_mlock
         self.lparams.use_mmap = self.params.use_mmap

-        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.params.model.encode("utf8"), self.lparams)
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")

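Note: with loading now split into two steps, only the context handle is checked above. A hedged variant that also guards the model handle (the extra check is a suggestion, not part of this commit):

self.model = llama_cpp.llama_load_model_from_file(
    self.params.model.encode("utf8"), self.lparams)
if (not self.model):
    raise RuntimeError(f"error: failed to load model '{self.params.model}'")

self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
if (not self.ctx):
    raise RuntimeError(f"error: failed to create context for '{self.params.model}'")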
@@ -181,12 +187,12 @@ prompt: '{self.params.prompt}'
 number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

             for i in range(len(self.embd_inp)):
-                print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+                print(f"{self.embd_inp[i]} -> '{self.token_to_str(self.embd_inp[i])}'", file=sys.stderr)

             if (self.params.n_keep > 0):
                 print("static prompt based on n_keep: '")
                 for i in range(self.params.n_keep):
-                    print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                    print(self.token_to_str(self.embd_inp[i]), file=sys.stderr)
                 print("'", file=sys.stderr)
             print(file=sys.stderr)

@@ -339,7 +345,7 @@ n_keep = {self.params.n_keep}
                 candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))

                 # Apply penalties
-                nl_logit = logits[llama_cpp.llama_token_nl()]
+                nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
                 last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

                 _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
@@ -380,7 +386,7 @@ n_keep = {self.params.n_keep}
                 self.last_n_tokens.append(id)

             # replace end of text token with newline token when in interactive mode
-            if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
+            if (id == llama_cpp.llama_token_eos(self.ctx) and self.params.interactive and not self.params.instruct):
                 id = self.llama_token_newline[0]
                 self.embd.append(id)
                 if (self.use_antiprompt()):
@@ -437,7 +443,7 @@ n_keep = {self.params.n_keep}
                         break

             # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(self.ctx):
                 if (not self.params.instruct):
                     for i in self.llama_token_eot:
                         yield i
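Note: these hunks track an upstream API change in which the special-token getters take the context. An illustrative snippet (llama_token_bos is assumed to follow the same convention; it is not touched by this commit):

eos = llama_cpp.llama_token_eos(self.ctx)  # end-of-sequence token id
nl  = llama_cpp.llama_token_nl(self.ctx)   # newline token id
bos = llama_cpp.llama_token_bos(self.ctx)  # beginning-of-sequence token id (assumed)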
@@ -464,10 +470,18 @@ n_keep = {self.params.n_keep}
         llama_cpp.llama_free(self.ctx)
         self.set_color(util.CONSOLE_COLOR_DEFAULT)

+    def token_to_str(self, token_id: int) -> bytes:
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece_with_model(
+            self.model, llama_cpp.llama_token(token_id), buffer, size)
+        assert n <= size
+        return bytes(buffer[:n])
+
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
+            yield self.token_to_str(id).decode("utf8", errors="ignore")

     # write input
     def input(self, prompt: str):
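Note: the new token_to_str helper uses a fixed 32-byte buffer and asserts that the piece fits. A hedged alternative, assuming llama_token_to_piece_with_model follows the llama.cpp convention of returning the negated required length when the buffer is too small, retries with a larger buffer instead of asserting:

def token_to_str(self, token_id: int) -> bytes:
    size = 32
    buffer = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece_with_model(
        self.model, llama_cpp.llama_token(token_id), buffer, size)
    if n < 0:
        # Assumed convention: a negative return value is -(required length).
        size = -n
        buffer = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece_with_model(
            self.model, llama_cpp.llama_token(token_id), buffer, size)
    assert 0 <= n <= size
    return bytes(buffer[:n])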
@@ -481,7 +495,7 @@ n_keep = {self.params.n_keep}
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
+            cur_char = self.token_to_str(id)

             # Add remainder of missing bytes
             if None in self.multibyte_fix:
@@ -1,15 +1,17 @@
-import llama_cpp
+import ctypes
+import os
 import multiprocessing

 import llama_cpp

 N_THREADS = multiprocessing.cpu_count()
+MODEL_PATH = os.environ.get('MODEL', "../models/7B/ggml-model.bin")

 prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"

 lparams = llama_cpp.llama_context_default_params()
-ctx = llama_cpp.llama_init_from_file(b"../models/7B/ggml-model.bin", lparams)
+model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode('utf-8'), lparams)
+ctx = llama_cpp.llama_new_context_with_model(model, lparams)

 # determine the required inference memory per token:
 tmp = [0, 1, 2, 3]
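Note: the standalone example now reads its model path from the MODEL environment variable instead of a hard-coded path. One way to point it elsewhere (the path below is a placeholder) is to set the variable before the module-level code runs, or to export MODEL in the shell before launching the script:

import os

# Placeholder path; must be set before the example's module-level code executes.
os.environ["MODEL"] = "./models/7B/ggml-model.bin"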
@@ -58,7 +60,8 @@ while remaining_tokens > 0:
         llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
         for token_id in range(n_vocab)
     ])
-    candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
+    candidates_p = llama_cpp.ctypes.pointer(
+        llama_cpp.llama_token_data_array(_arr, len(_arr), False))

     _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
     llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
@@ -68,9 +71,9 @@ while remaining_tokens > 0:
         _arr,
         last_n_repeat, frequency_penalty, presence_penalty)

-    llama_cpp.llama_sample_top_k(ctx, candidates_p, 40)
-    llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.8)
-    llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.2)
+    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
+    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
+    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=0.2)
     id = llama_cpp.llama_sample_token(ctx, candidates_p)

     last_n_tokens_data = last_n_tokens_data[1:] + [id]
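Note: with the samplers called via explicit keyword arguments, the sampling chain reads as one unit. A sketch that wraps only the calls shown above into a helper (the function name and default values are illustrative, not part of the example):

def sample_next_token(ctx, candidates_p, top_k=40, top_p=0.8, temp=0.2):
    # Narrow the candidate set, then draw a token; same call order as the example.
    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=top_k, min_keep=1)
    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=top_p, min_keep=1)
    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=temp)
    return llama_cpp.llama_sample_token(ctx, candidates_p)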
@@ -86,13 +89,18 @@ while remaining_tokens > 0:
                 break
     if not input_noecho:
         for id in embd:
+            size = 32
+            buffer = (ctypes.c_char * size)()
+            n = llama_cpp.llama_token_to_piece_with_model(
+                model, llama_cpp.llama_token(id), buffer, size)
+            assert n <= size
             print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
+                buffer[:n].decode('utf-8'),
                 end="",
                 flush=True,
             )

-    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos():
+    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos(ctx):
         break

 print()
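Note: the inline piece decoding in the echo loop could also be pulled into a small helper mirroring token_to_str from the chat example; a sketch reusing only calls that appear in this diff (the helper name is illustrative):

def token_to_piece(model, token_id: int) -> bytes:
    size = 32
    buffer = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece_with_model(
        model, llama_cpp.llama_token(token_id), buffer, size)
    assert n <= size
    return bytes(buffer[:n])

# Usage inside the echo loop:
#   print(token_to_piece(model, id).decode("utf-8", errors="ignore"), end="", flush=True)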