Fix low_level_api_chat_cpp example to match current API ()

* Fix low_level_api_chat_cpp to match current API

* Fix low_level_api_chat_cpp to match current API

* Using None instead of empty string to so that default prompt template can be used if no prompt provided

---------

Co-authored-by: Anil Pathak <anil@heyday.com>
This commit is contained in:
anil 2024-01-15 09:46:35 -06:00 committed by GitHub
parent c689ccc728
commit 1eaace8ea3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 15 deletions

View file

@ -106,7 +106,7 @@ def gpt_params_parse(argv = None):
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta") parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model") parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt") parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt")
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file") parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session") parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix") parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")

View file

@ -62,7 +62,7 @@ specified) expect poor results""", file=sys.stderr)
self.multibyte_fix = [] self.multibyte_fix = []
# model load # model load
self.lparams = llama_cpp.llama_context_default_params() self.lparams = llama_cpp.llama_model_default_params()
self.lparams.n_ctx = self.params.n_ctx self.lparams.n_ctx = self.params.n_ctx
self.lparams.n_parts = self.params.n_parts self.lparams.n_parts = self.params.n_parts
self.lparams.seed = self.params.seed self.lparams.seed = self.params.seed
@ -72,7 +72,11 @@ specified) expect poor results""", file=sys.stderr)
self.model = llama_cpp.llama_load_model_from_file( self.model = llama_cpp.llama_load_model_from_file(
self.params.model.encode("utf8"), self.lparams) self.params.model.encode("utf8"), self.lparams)
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
# Context Params.
self.cparams = llama_cpp.llama_context_default_params()
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
if (not self.ctx): if (not self.ctx):
raise RuntimeError(f"error: failed to load model '{self.params.model}'") raise RuntimeError(f"error: failed to load model '{self.params.model}'")
@ -244,7 +248,7 @@ n_keep = {self.params.n_keep}
# tokenize a prompt # tokenize a prompt
def _tokenize(self, prompt, bos=True): def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))() _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos) _n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False)
return _arr[:_n] return _arr[:_n]
def set_color(self, c): def set_color(self, c):
@ -304,7 +308,7 @@ n_keep = {self.params.n_keep}
self.n_past += n_eval""" self.n_past += n_eval"""
if (llama_cpp.llama_eval( if (llama_cpp.llama_eval(
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past
) != 0): ) != 0):
raise Exception("Failed to llama_eval!") raise Exception("Failed to llama_eval!")
@ -332,7 +336,7 @@ n_keep = {self.params.n_keep}
id = 0 id = 0
logits = llama_cpp.llama_get_logits(self.ctx) logits = llama_cpp.llama_get_logits(self.ctx)
n_vocab = llama_cpp.llama_n_vocab(self.ctx) n_vocab = llama_cpp.llama_n_vocab(self.model)
# Apply params.logit_bias map # Apply params.logit_bias map
for key, value in self.params.logit_bias.items(): for key, value in self.params.logit_bias.items():
@ -349,12 +353,20 @@ n_keep = {self.params.n_keep}
last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx) last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:]) _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p, llama_cpp.llama_sample_repetition_penalties(
_arr, ctx=self.ctx,
last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty)) candidates=candidates_p,
llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p, last_tokens_data = _arr,
_arr, penalty_last_n = last_n_repeat,
last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty)) penalty_repeat = llama_cpp.c_float(self.params.repeat_penalty),
penalty_freq = llama_cpp.c_float(self.params.frequency_penalty),
penalty_present = llama_cpp.c_float(self.params.presence_penalty),
)
# NOT PRESENT IN CURRENT VERSION ?
# llama_cpp.llama_sample_frequency_and_presence_penalti(self.ctx, candidates_p,
# _arr,
# last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
if not self.params.penalize_nl: if not self.params.penalize_nl:
logits[llama_cpp.llama_token_nl()] = nl_logit logits[llama_cpp.llama_token_nl()] = nl_logit
@ -473,7 +485,7 @@ n_keep = {self.params.n_keep}
def token_to_str(self, token_id: int) -> bytes: def token_to_str(self, token_id: int) -> bytes:
size = 32 size = 32
buffer = (ctypes.c_char * size)() buffer = (ctypes.c_char * size)()
n = llama_cpp.llama_token_to_piece_with_model( n = llama_cpp.llama_token_to_piece(
self.model, llama_cpp.llama_token(token_id), buffer, size) self.model, llama_cpp.llama_token(token_id), buffer, size)
assert n <= size assert n <= size
return bytes(buffer[:n]) return bytes(buffer[:n])
@ -532,6 +544,9 @@ n_keep = {self.params.n_keep}
print(i,end="",flush=True) print(i,end="",flush=True)
self.params.input_echo = False self.params.input_echo = False
# Using string instead of tokens to check for antiprompt,
# It is more reliable than tokens for interactive mode.
generated_str = ""
while self.params.interactive: while self.params.interactive:
self.set_color(util.CONSOLE_COLOR_USER_INPUT) self.set_color(util.CONSOLE_COLOR_USER_INPUT)
if (self.params.instruct): if (self.params.instruct):
@ -546,6 +561,10 @@ n_keep = {self.params.n_keep}
try: try:
for i in self.output(): for i in self.output():
print(i,end="",flush=True) print(i,end="",flush=True)
generated_str += i
for ap in self.params.antiprompt:
if generated_str.endswith(ap):
raise KeyboardInterrupt
except KeyboardInterrupt: except KeyboardInterrupt:
self.set_color(util.CONSOLE_COLOR_DEFAULT) self.set_color(util.CONSOLE_COLOR_DEFAULT)
if not self.params.instruct: if not self.params.instruct:
@ -561,7 +580,7 @@ if __name__ == "__main__":
time_now = datetime.now() time_now = datetime.now()
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}. prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}s requests immediately and with details and precision. {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other. Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long. The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown. The transcript only includes text, it does not include markup like HTML and Markdown.
@ -575,8 +594,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae. {AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color. {USER_NAME}: Name a color.
{AI_NAME}: Blue {AI_NAME}: Blue
{USER_NAME}:""" {USER_NAME}: """
params = gpt_params_parse() params = gpt_params_parse()
if params.prompt is None and params.file is None:
params.prompt = prompt
with LLaMAInteract(params) as m: with LLaMAInteract(params) as m:
m.interact() m.interact()