Fix low_level_api_chat_cpp example to match current API (#1086)
* Fix low_level_api_chat_cpp to match current API * Fix low_level_api_chat_cpp to match current API * Using None instead of empty string to so that default prompt template can be used if no prompt provided --------- Co-authored-by: Anil Pathak <anil@heyday.com>
This commit is contained in:
parent
c689ccc728
commit
1eaace8ea3
2 changed files with 37 additions and 15 deletions
examples/low_level_api
|
@ -106,7 +106,7 @@ def gpt_params_parse(argv = None):
|
||||||
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
|
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
|
||||||
|
|
||||||
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
|
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
|
||||||
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
|
parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt")
|
||||||
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
|
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
|
||||||
parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
|
parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
|
||||||
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
|
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
|
||||||
|
|
|
@ -62,7 +62,7 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
self.multibyte_fix = []
|
self.multibyte_fix = []
|
||||||
|
|
||||||
# model load
|
# model load
|
||||||
self.lparams = llama_cpp.llama_context_default_params()
|
self.lparams = llama_cpp.llama_model_default_params()
|
||||||
self.lparams.n_ctx = self.params.n_ctx
|
self.lparams.n_ctx = self.params.n_ctx
|
||||||
self.lparams.n_parts = self.params.n_parts
|
self.lparams.n_parts = self.params.n_parts
|
||||||
self.lparams.seed = self.params.seed
|
self.lparams.seed = self.params.seed
|
||||||
|
@ -72,7 +72,11 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
|
|
||||||
self.model = llama_cpp.llama_load_model_from_file(
|
self.model = llama_cpp.llama_load_model_from_file(
|
||||||
self.params.model.encode("utf8"), self.lparams)
|
self.params.model.encode("utf8"), self.lparams)
|
||||||
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
|
|
||||||
|
# Context Params.
|
||||||
|
self.cparams = llama_cpp.llama_context_default_params()
|
||||||
|
|
||||||
|
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
|
||||||
if (not self.ctx):
|
if (not self.ctx):
|
||||||
raise RuntimeError(f"error: failed to load model '{self.params.model}'")
|
raise RuntimeError(f"error: failed to load model '{self.params.model}'")
|
||||||
|
|
||||||
|
@ -244,7 +248,7 @@ n_keep = {self.params.n_keep}
|
||||||
# tokenize a prompt
|
# tokenize a prompt
|
||||||
def _tokenize(self, prompt, bos=True):
|
def _tokenize(self, prompt, bos=True):
|
||||||
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
|
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
|
||||||
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
|
_n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False)
|
||||||
return _arr[:_n]
|
return _arr[:_n]
|
||||||
|
|
||||||
def set_color(self, c):
|
def set_color(self, c):
|
||||||
|
@ -304,7 +308,7 @@ n_keep = {self.params.n_keep}
|
||||||
self.n_past += n_eval"""
|
self.n_past += n_eval"""
|
||||||
|
|
||||||
if (llama_cpp.llama_eval(
|
if (llama_cpp.llama_eval(
|
||||||
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
|
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past
|
||||||
) != 0):
|
) != 0):
|
||||||
raise Exception("Failed to llama_eval!")
|
raise Exception("Failed to llama_eval!")
|
||||||
|
|
||||||
|
@ -332,7 +336,7 @@ n_keep = {self.params.n_keep}
|
||||||
id = 0
|
id = 0
|
||||||
|
|
||||||
logits = llama_cpp.llama_get_logits(self.ctx)
|
logits = llama_cpp.llama_get_logits(self.ctx)
|
||||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
n_vocab = llama_cpp.llama_n_vocab(self.model)
|
||||||
|
|
||||||
# Apply params.logit_bias map
|
# Apply params.logit_bias map
|
||||||
for key, value in self.params.logit_bias.items():
|
for key, value in self.params.logit_bias.items():
|
||||||
|
@ -349,12 +353,20 @@ n_keep = {self.params.n_keep}
|
||||||
last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
|
last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
|
||||||
|
|
||||||
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
|
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
|
||||||
llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
|
llama_cpp.llama_sample_repetition_penalties(
|
||||||
_arr,
|
ctx=self.ctx,
|
||||||
last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
|
candidates=candidates_p,
|
||||||
llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
|
last_tokens_data = _arr,
|
||||||
_arr,
|
penalty_last_n = last_n_repeat,
|
||||||
last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
|
penalty_repeat = llama_cpp.c_float(self.params.repeat_penalty),
|
||||||
|
penalty_freq = llama_cpp.c_float(self.params.frequency_penalty),
|
||||||
|
penalty_present = llama_cpp.c_float(self.params.presence_penalty),
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOT PRESENT IN CURRENT VERSION ?
|
||||||
|
# llama_cpp.llama_sample_frequency_and_presence_penalti(self.ctx, candidates_p,
|
||||||
|
# _arr,
|
||||||
|
# last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
|
||||||
|
|
||||||
if not self.params.penalize_nl:
|
if not self.params.penalize_nl:
|
||||||
logits[llama_cpp.llama_token_nl()] = nl_logit
|
logits[llama_cpp.llama_token_nl()] = nl_logit
|
||||||
|
@ -473,7 +485,7 @@ n_keep = {self.params.n_keep}
|
||||||
def token_to_str(self, token_id: int) -> bytes:
|
def token_to_str(self, token_id: int) -> bytes:
|
||||||
size = 32
|
size = 32
|
||||||
buffer = (ctypes.c_char * size)()
|
buffer = (ctypes.c_char * size)()
|
||||||
n = llama_cpp.llama_token_to_piece_with_model(
|
n = llama_cpp.llama_token_to_piece(
|
||||||
self.model, llama_cpp.llama_token(token_id), buffer, size)
|
self.model, llama_cpp.llama_token(token_id), buffer, size)
|
||||||
assert n <= size
|
assert n <= size
|
||||||
return bytes(buffer[:n])
|
return bytes(buffer[:n])
|
||||||
|
@ -532,6 +544,9 @@ n_keep = {self.params.n_keep}
|
||||||
print(i,end="",flush=True)
|
print(i,end="",flush=True)
|
||||||
self.params.input_echo = False
|
self.params.input_echo = False
|
||||||
|
|
||||||
|
# Using string instead of tokens to check for antiprompt,
|
||||||
|
# It is more reliable than tokens for interactive mode.
|
||||||
|
generated_str = ""
|
||||||
while self.params.interactive:
|
while self.params.interactive:
|
||||||
self.set_color(util.CONSOLE_COLOR_USER_INPUT)
|
self.set_color(util.CONSOLE_COLOR_USER_INPUT)
|
||||||
if (self.params.instruct):
|
if (self.params.instruct):
|
||||||
|
@ -546,6 +561,10 @@ n_keep = {self.params.n_keep}
|
||||||
try:
|
try:
|
||||||
for i in self.output():
|
for i in self.output():
|
||||||
print(i,end="",flush=True)
|
print(i,end="",flush=True)
|
||||||
|
generated_str += i
|
||||||
|
for ap in self.params.antiprompt:
|
||||||
|
if generated_str.endswith(ap):
|
||||||
|
raise KeyboardInterrupt
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.set_color(util.CONSOLE_COLOR_DEFAULT)
|
self.set_color(util.CONSOLE_COLOR_DEFAULT)
|
||||||
if not self.params.instruct:
|
if not self.params.instruct:
|
||||||
|
@ -561,7 +580,7 @@ if __name__ == "__main__":
|
||||||
time_now = datetime.now()
|
time_now = datetime.now()
|
||||||
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
|
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
|
||||||
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
|
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
|
||||||
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
|
Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
|
||||||
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
|
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
|
||||||
The transcript only includes text, it does not include markup like HTML and Markdown.
|
The transcript only includes text, it does not include markup like HTML and Markdown.
|
||||||
|
|
||||||
|
@ -575,8 +594,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
|
||||||
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
||||||
{USER_NAME}: Name a color.
|
{USER_NAME}: Name a color.
|
||||||
{AI_NAME}: Blue
|
{AI_NAME}: Blue
|
||||||
{USER_NAME}:"""
|
{USER_NAME}: """
|
||||||
|
|
||||||
params = gpt_params_parse()
|
params = gpt_params_parse()
|
||||||
|
if params.prompt is None and params.file is None:
|
||||||
|
params.prompt = prompt
|
||||||
|
|
||||||
with LLaMAInteract(params) as m:
|
with LLaMAInteract(params) as m:
|
||||||
m.interact()
|
m.interact()
|
||||||
|
|
Loading…
Reference in a new issue