From 3bb45f16589cfe3649330f78114237f64c8f5080 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Mon, 10 Apr 2023 16:38:45 +0200 Subject: [PATCH 1/2] More reasonable defaults --- examples/low_level_api/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/low_level_api/common.py b/examples/low_level_api/common.py index f16980c..58a5688 100644 --- a/examples/low_level_api/common.py +++ b/examples/low_level_api/common.py @@ -50,7 +50,7 @@ class GptParams: # If chat ended prematurely, append this to the conversation to fix it. # Set to "\nUser:" etc. # This is an alternative to input_prefix which always adds it, so it potentially duplicates "User:"" - fix_prefix: str = " " + fix_prefix: str = "" output_postfix: str = "" input_echo: bool = True, From 53d17ad0033e99d9ac5c3fb4855710383fb1f202 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Mon, 17 Apr 2023 14:45:28 +0200 Subject: [PATCH 2/2] Fixed end of text wrong type, and fix n_predict behaviour --- examples/low_level_api/common.py | 2 +- examples/low_level_api/low_level_api_chat_cpp.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/low_level_api/common.py b/examples/low_level_api/common.py index 58a5688..061ec3a 100644 --- a/examples/low_level_api/common.py +++ b/examples/low_level_api/common.py @@ -75,7 +75,7 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None): parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p") parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k") parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp") - parser.add_argument("--n_predict", type=int, default=128, help="number of model parts",dest="n_predict") + parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict") parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n") parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty") parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch") diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index a61a55e..d64ee8f 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -144,6 +144,7 @@ specified) expect poor results""", file=sys.stderr) # determine newline token self.llama_token_newline = self._tokenize("\n", False) + self.llama_token_eot = self._tokenize(" [end of text]\n", False) if (self.params.verbose_prompt): print(f""" @@ -203,16 +204,16 @@ n_keep = {self.params.n_keep} _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos) return _arr[:_n] - def use_antiprompt(self): - return len(self.first_antiprompt) > 0 - def set_color(self, c): if (self.params.use_color): print(c, end="") + def use_antiprompt(self): + return len(self.first_antiprompt) > 0 + # generate tokens def generate(self): - while self.remaining_tokens > 0 or self.params.interactive: + while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1: # predict if len(self.embd) > 0: # infinite text generation via context swapping @@ -313,7 +314,7 @@ n_keep = {self.params.n_keep} # end of text token if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(): if (not self.params.instruct): - for i in " [end of text]\n": + for i in self.llama_token_eot: yield i break