Fix bug in init_break not being set when exited via antiprompt and others.

This commit is contained in:
Mug 2023-04-05 14:47:24 +02:00
parent 99ceecfccd
commit 283e59c5e9

View file

@ -33,6 +33,7 @@ class LLaMAInteract:
top_p: float=1., top_p: float=1.,
temp: float=1.0, temp: float=1.0,
repeat_penalty: float=1, repeat_penalty: float=1,
init_break: bool=True,
instruct_inp_prefix: str="\n\n### Instruction:\n\n", instruct_inp_prefix: str="\n\n### Instruction:\n\n",
instruct_inp_suffix: str="\n\n### Response:\n\n", instruct_inp_suffix: str="\n\n### Response:\n\n",
) -> None: ) -> None:
@ -48,6 +49,7 @@ class LLaMAInteract:
self.top_p=top_p self.top_p=top_p
self.temp=temp self.temp=temp
self.repeat_penalty=repeat_penalty self.repeat_penalty=repeat_penalty
self.init_break = init_break
# runtime args # runtime args
self.input_consumed = 0 self.input_consumed = 0
@ -81,9 +83,6 @@ class LLaMAInteract:
if (len(primer) > 0): if (len(primer) > 0):
self.embd_inp += self._tokenize(primer) self.embd_inp += self._tokenize(primer)
# break immediately if using instruct
self.init_break = self.instruct
# number of tokens to keep when resetting context # number of tokens to keep when resetting context
if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct): if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
self.n_keep = len(self.embd_inp) self.n_keep = len(self.embd_inp)
@ -182,13 +181,14 @@ class LLaMAInteract:
if (len(self.embd_inp) <= self.input_consumed): if (len(self.embd_inp) <= self.input_consumed):
# if antiprompt is present, stop # if antiprompt is present, stop
if (self.use_antiprompt()): if (self.use_antiprompt()):
for i in self.first_antiprompt: if True in [
if i == self.last_n_tokens[-len(i):]: i == self.last_n_tokens[-len(i):]
return for i in self.first_antiprompt
]:
break
# if we are using instruction mode, and we have processed the initial prompt # if we are using instruction mode, and we have processed the initial prompt
if (self.init_break): if (self.init_break):
self.init_break = False
break break
# if end of generation # if end of generation
@ -201,6 +201,8 @@ class LLaMAInteract:
self.embd_inp += self.first_antiprompt[0] self.embd_inp += self.first_antiprompt[0]
break break
self.init_break = False
def __enter__(self): def __enter__(self):
return self return self