Fix low_level_api_chat_cpp example to match current API (#1086)

* Fix low_level_api_chat_cpp to match current API

* Fix low_level_api_chat_cpp to match current API

* Using None instead of empty string so that the default prompt template can be used if no prompt is provided

---------

Co-authored-by: Anil Pathak <anil@heyday.com>
anil 2024-01-15 09:46:35 -06:00 committed by GitHub
parent c689ccc728
commit 1eaace8ea3
2 changed files with 37 additions and 15 deletions

View file

@@ -106,7 +106,7 @@ def gpt_params_parse(argv = None):
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt")
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")

View file

@@ -62,7 +62,7 @@ specified) expect poor results""", file=sys.stderr)
self.multibyte_fix = []
# model load
-self.lparams = llama_cpp.llama_context_default_params()
+self.lparams = llama_cpp.llama_model_default_params()
self.lparams.n_ctx = self.params.n_ctx
self.lparams.n_parts = self.params.n_parts
self.lparams.seed = self.params.seed
@@ -72,7 +72,11 @@ specified) expect poor results""", file=sys.stderr)
self.model = llama_cpp.llama_load_model_from_file(
self.params.model.encode("utf8"), self.lparams)
-self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
+# Context Params.
+self.cparams = llama_cpp.llama_context_default_params()
+self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
if (not self.ctx):
raise RuntimeError(f"error: failed to load model '{self.params.model}'")
@@ -244,7 +248,7 @@ n_keep = {self.params.n_keep}
# tokenize a prompt
def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
-_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
+_n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False)
return _arr[:_n]
def set_color(self, c):
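
The tokenize change reflects the newer llama_tokenize signature, which takes the model rather than the context, an explicit input length, and a trailing flag for special tokens. Roughly, as a standalone helper under those assumptions:

    import llama_cpp

    def tokenize(model, prompt: str, add_bos: bool = True):
        text = prompt.encode("utf-8", errors="ignore")
        # Over-allocate the output buffer; a few tokens per input byte is a safe upper bound here.
        buf = (llama_cpp.llama_token * ((len(text) + 1) * 4))()
        n = llama_cpp.llama_tokenize(
            model,       # tokenization is now a model-level operation
            text,
            len(text),   # explicit byte length of the input
            buf,
            len(buf),    # capacity of the output buffer
            add_bos,     # prepend the BOS token
            False,       # do not parse special/control tokens
        )
        return list(buf[:n])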
@@ -304,7 +308,7 @@ n_keep = {self.params.n_keep}
self.n_past += n_eval"""
if (llama_cpp.llama_eval(
-self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
+self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past
) != 0):
raise Exception("Failed to llama_eval!")
@@ -332,7 +336,7 @@ n_keep = {self.params.n_keep}
id = 0
logits = llama_cpp.llama_get_logits(self.ctx)
-n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+n_vocab = llama_cpp.llama_n_vocab(self.model)
# Apply params.logit_bias map
for key, value in self.params.logit_bias.items():
@@ -349,12 +353,20 @@ n_keep = {self.params.n_keep}
last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
-llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
-_arr,
-last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
-llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
-_arr,
-last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
+llama_cpp.llama_sample_repetition_penalties(
+ctx=self.ctx,
+candidates=candidates_p,
+last_tokens_data = _arr,
+penalty_last_n = last_n_repeat,
+penalty_repeat = llama_cpp.c_float(self.params.repeat_penalty),
+penalty_freq = llama_cpp.c_float(self.params.frequency_penalty),
+penalty_present = llama_cpp.c_float(self.params.presence_penalty),
+)
+# NOT PRESENT IN CURRENT VERSION ?
+# llama_cpp.llama_sample_frequency_and_presence_penalti(self.ctx, candidates_p,
+# _arr,
+# last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
if not self.params.penalize_nl:
logits[llama_cpp.llama_token_nl()] = nl_logit
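
The two separate penalty calls of the old API (repetition penalty, then frequency/presence penalties) are folded into the single llama_sample_repetition_penalties call used above. A rough standalone sketch of that call, assuming candidates_p is the usual llama_token_data_array built from the logits:

    import llama_cpp

    def apply_penalties(ctx, candidates_p, last_tokens, repeat_penalty, freq_penalty, presence_penalty):
        # One call now applies repetition, frequency and presence penalties together.
        arr = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
        llama_cpp.llama_sample_repetition_penalties(
            ctx=ctx,
            candidates=candidates_p,
            last_tokens_data=arr,
            penalty_last_n=len(last_tokens),
            penalty_repeat=llama_cpp.c_float(repeat_penalty),
            penalty_freq=llama_cpp.c_float(freq_penalty),
            penalty_present=llama_cpp.c_float(presence_penalty),
        )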
@@ -473,7 +485,7 @@ n_keep = {self.params.n_keep}
def token_to_str(self, token_id: int) -> bytes:
size = 32
buffer = (ctypes.c_char * size)()
-n = llama_cpp.llama_token_to_piece_with_model(
+n = llama_cpp.llama_token_to_piece(
self.model, llama_cpp.llama_token(token_id), buffer, size)
assert n <= size
return bytes(buffer[:n])
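
llama_token_to_piece replaces the older llama_token_to_piece_with_model; it writes the token's text into a caller-provided buffer and returns how many bytes were written. A small helper sketch under the same assumptions as the hunk above:

    import ctypes
    import llama_cpp

    def token_to_bytes(model, token_id: int, size: int = 32) -> bytes:
        buf = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece(model, llama_cpp.llama_token(token_id), buf, size)
        assert n <= size  # same guard the example uses
        return bytes(buf[:n])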
@@ -532,6 +544,9 @@ n_keep = {self.params.n_keep}
print(i,end="",flush=True)
self.params.input_echo = False
+# Using string instead of tokens to check for antiprompt,
+# It is more reliable than tokens for interactive mode.
+generated_str = ""
while self.params.interactive:
self.set_color(util.CONSOLE_COLOR_USER_INPUT)
if (self.params.instruct):
@@ -546,6 +561,10 @@ n_keep = {self.params.n_keep}
try:
for i in self.output():
print(i,end="",flush=True)
+generated_str += i
+for ap in self.params.antiprompt:
+if generated_str.endswith(ap):
+raise KeyboardInterrupt
except KeyboardInterrupt:
self.set_color(util.CONSOLE_COLOR_DEFAULT)
if not self.params.instruct:
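
The antiprompt check above accumulates the decoded output text and compares it against each antiprompt string, which is more robust than comparing token ids (the same text can tokenize differently depending on context). A stripped-down sketch of that loop in plain Python:

    def stream_until_antiprompt(pieces, antiprompts):
        # pieces is any iterable of decoded text chunks (like self.output() in the example).
        generated = ""
        for piece in pieces:
            print(piece, end="", flush=True)
            generated += piece
            if any(generated.endswith(ap) for ap in antiprompts):
                break  # stop generation and hand control back to the user
        return generated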
@@ -561,7 +580,7 @@ if __name__ == "__main__":
time_now = datetime.now()
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}s requests immediately and with details and precision.
-There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
+Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
@@ -575,8 +594,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""
{USER_NAME}: """
params = gpt_params_parse()
+if params.prompt is None and params.file is None:
+params.prompt = prompt
with LLaMAInteract(params) as m:
m.interact()