From 53d17ad0033e99d9ac5c3fb4855710383fb1f202 Mon Sep 17 00:00:00 2001
From: Mug <>
Date: Mon, 17 Apr 2023 14:45:28 +0200
Subject: [PATCH] Fix wrong end-of-text token type and fix n_predict behaviour

---
 examples/low_level_api/common.py                 |  2 +-
 examples/low_level_api/low_level_api_chat_cpp.py | 11 ++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/low_level_api/common.py b/examples/low_level_api/common.py
index 58a5688..061ec3a 100644
--- a/examples/low_level_api/common.py
+++ b/examples/low_level_api/common.py
@@ -75,7 +75,7 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
     parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
     parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
     parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
-    parser.add_argument("--n_predict", type=int, default=128, help="number of model parts",dest="n_predict")
+    parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
     parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
     parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
     parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py
index a61a55e..d64ee8f 100644
--- a/examples/low_level_api/low_level_api_chat_cpp.py
+++ b/examples/low_level_api/low_level_api_chat_cpp.py
@@ -144,6 +144,7 @@ specified) expect poor results""", file=sys.stderr)
 
         # determine newline token
         self.llama_token_newline = self._tokenize("\n", False)
+        self.llama_token_eot = self._tokenize(" [end of text]\n", False)
 
         if (self.params.verbose_prompt):
             print(f"""
@@ -203,16 +204,16 @@ n_keep = {self.params.n_keep}
         _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
         return _arr[:_n]
 
-    def use_antiprompt(self):
-        return len(self.first_antiprompt) > 0
-
     def set_color(self, c):
         if (self.params.use_color):
             print(c, end="")
 
+    def use_antiprompt(self):
+        return len(self.first_antiprompt) > 0
+
     # generate tokens
     def generate(self):
-        while self.remaining_tokens > 0 or self.params.interactive:
+        while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
             # predict
             if len(self.embd) > 0:
                 # infinite text generation via context swapping
@@ -313,7 +314,7 @@ n_keep = {self.params.n_keep}
 
         # end of text token
         if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
             if (not self.params.instruct):
-                for i in " [end of text]\n":
+                for i in self.llama_token_eot:
                     yield i
                 break
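Note on the n_predict change: the generate() loop previously stopped as soon as remaining_tokens ran out, so the -1 = infinity setting now documented in the help text had no effect outside interactive mode. The corrected stop condition boils down to the check below (a minimal sketch; should_continue is a hypothetical helper written for illustration, not part of the example):

    def should_continue(remaining_tokens: int, interactive: bool, n_predict: int) -> bool:
        # Keep generating while the token budget holds, while in
        # interactive mode, or when n_predict == -1 (infinite generation).
        return remaining_tokens > 0 or interactive or n_predict == -1

    # With n_predict == -1, an exhausted budget no longer stops generation.
    assert should_continue(0, False, -1)
    assert not should_continue(0, False, 128)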
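Note on the end-of-text change: generate() yields llama token ids (ints), but on llama_token_eos() the old code yielded the individual characters of the string " [end of text]\n", handing callers a str where they expect a token. Tokenizing the banner once at startup keeps the yielded type uniform. A minimal sketch of the idea, assuming an initialized llama_cpp context (the ChatSketch class and its token-buffer sizing are illustrative assumptions, not the example's exact code):

    import llama_cpp

    class ChatSketch:
        def __init__(self, ctx):
            self.ctx = ctx
            # Token ids for the end-of-text banner, computed once at setup.
            self.llama_token_eot = self._tokenize(" [end of text]\n", False)

        def _tokenize(self, prompt: str, bos: bool = True):
            # Buffer size is an assumption: one token per character plus
            # one is a comfortable upper bound for this short ASCII string.
            _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
            _n = llama_cpp.llama_tokenize(
                self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos
            )
            return _arr[:_n]

        def on_end_of_text(self):
            # Yield token ids, the same type generate() yields everywhere else.
            for token in self.llama_token_eot:
                yield token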