Merge pull request #87 from SagsMug/main
Fix TypeError in low_level chat
This commit is contained in:
commit
4ce6670bbd
2 changed files with 8 additions and 7 deletions
|
@ -50,7 +50,7 @@ class GptParams:
|
||||||
# If chat ended prematurely, append this to the conversation to fix it.
|
# If chat ended prematurely, append this to the conversation to fix it.
|
||||||
# Set to "\nUser:" etc.
|
# Set to "\nUser:" etc.
|
||||||
# This is an alternative to input_prefix which always adds it, so it potentially duplicates "User:""
|
# This is an alternative to input_prefix which always adds it, so it potentially duplicates "User:""
|
||||||
fix_prefix: str = " "
|
fix_prefix: str = ""
|
||||||
output_postfix: str = ""
|
output_postfix: str = ""
|
||||||
input_echo: bool = True,
|
input_echo: bool = True,
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
|
||||||
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
|
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
|
||||||
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
|
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
|
||||||
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
|
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
|
||||||
parser.add_argument("--n_predict", type=int, default=128, help="number of model parts",dest="n_predict")
|
parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
|
||||||
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
|
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
|
||||||
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
|
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
|
||||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
|
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
|
||||||
|
|
|
@ -144,6 +144,7 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
|
|
||||||
# determine newline token
|
# determine newline token
|
||||||
self.llama_token_newline = self._tokenize("\n", False)
|
self.llama_token_newline = self._tokenize("\n", False)
|
||||||
|
self.llama_token_eot = self._tokenize(" [end of text]\n", False)
|
||||||
|
|
||||||
if (self.params.verbose_prompt):
|
if (self.params.verbose_prompt):
|
||||||
print(f"""
|
print(f"""
|
||||||
|
@ -203,16 +204,16 @@ n_keep = {self.params.n_keep}
|
||||||
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
|
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
|
||||||
return _arr[:_n]
|
return _arr[:_n]
|
||||||
|
|
||||||
def use_antiprompt(self):
|
|
||||||
return len(self.first_antiprompt) > 0
|
|
||||||
|
|
||||||
def set_color(self, c):
|
def set_color(self, c):
|
||||||
if (self.params.use_color):
|
if (self.params.use_color):
|
||||||
print(c, end="")
|
print(c, end="")
|
||||||
|
|
||||||
|
def use_antiprompt(self):
|
||||||
|
return len(self.first_antiprompt) > 0
|
||||||
|
|
||||||
# generate tokens
|
# generate tokens
|
||||||
def generate(self):
|
def generate(self):
|
||||||
while self.remaining_tokens > 0 or self.params.interactive:
|
while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
|
||||||
# predict
|
# predict
|
||||||
if len(self.embd) > 0:
|
if len(self.embd) > 0:
|
||||||
# infinite text generation via context swapping
|
# infinite text generation via context swapping
|
||||||
|
@ -313,7 +314,7 @@ n_keep = {self.params.n_keep}
|
||||||
# end of text token
|
# end of text token
|
||||||
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
|
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
|
||||||
if (not self.params.instruct):
|
if (not self.params.instruct):
|
||||||
for i in " [end of text]\n":
|
for i in self.llama_token_eot:
|
||||||
yield i
|
yield i
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue