Add iterative search to prevent instructions from being echoed, add ignore-eos, add no-mmap, fix a bug that echoed one character too many

Mug 2023-04-10 16:35:38 +02:00
parent 241d608bbb
commit 0cccb41a8f
2 changed files with 35 additions and 4 deletions


@@ -40,6 +40,7 @@ class GptParams:
     instruct: bool = False
     ignore_eos: bool = False
     perplexity: bool = False
+    use_mmap: bool = True
     use_mlock: bool = False
     mem_test: bool = False
     verbose_prompt: bool = False
@@ -110,7 +111,9 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
         dest="use_color"
     )
     parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
+    parser.add_argument("--no-mmap", action="store_false",help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",dest="use_mmap")
     parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
+    parser.add_argument("--verbose-prompt", action="store_true",help="print prompt before generation",dest="verbose_prompt")
     parser.add_argument(
         "-r",
         "--reverse-prompt",


@@ -26,6 +26,25 @@
 CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
 CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
 CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
+# Iterative search
+# Actively searches and prevents a pattern from being returned
+class IterSearch:
+    def __init__(self, pattern):
+        self.pattern = list(pattern)
+        self.buffer = []
+
+    def __call__(self, char):
+        self.buffer += [char]
+
+        if (self.pattern[:len(self.buffer)] == self.buffer):
+            if (len(self.buffer) >= len(self.pattern)):
+                self.buffer.clear()
+            return []
+
+        _tmp = self.buffer[:]
+        self.buffer.clear()
+        return _tmp
+
 # A LLaMA interactive session
 class LLaMAInteract:
     def __init__(self, params: GptParams) -> None:
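A standalone sketch of how IterSearch behaves, assuming token ids are plain ints: tokens are buffered while they extend a prefix of the pattern, a complete match is dropped, and a mismatch flushes the held tokens back to the caller.

class IterSearch:
    def __init__(self, pattern):
        self.pattern = list(pattern)
        self.buffer = []

    def __call__(self, char):
        self.buffer += [char]
        if self.pattern[:len(self.buffer)] == self.buffer:
            if len(self.buffer) >= len(self.pattern):
                self.buffer.clear()   # complete match: swallow the pattern
            return []                 # partial match: hold tokens back
        _tmp = self.buffer[:]         # mismatch: release everything held
        self.buffer.clear()
        return _tmp

antiecho = IterSearch([7, 8, 9])
stream = [1, 7, 8, 9, 2, 7, 3]
print([t for tok in stream for t in antiecho(tok)])  # [1, 2, 7, 3]

Note that the flush on mismatch is greedy and never re-scans the released tokens, so a pattern that starts inside a failed match (e.g. [7, 7, 8] fed the stream 7, 7, 7, 8) can still be echoed.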
@@ -69,6 +88,7 @@ specified) expect poor results""", file=sys.stderr)
         self.lparams.seed = self.params.seed
         self.lparams.memory_f16 = self.params.memory_f16
         self.lparams.use_mlock = self.params.use_mlock
+        self.lparams.use_mmap = self.params.use_mmap
 
         self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
         if (not self.ctx):
@@ -114,7 +134,9 @@ specified) expect poor results""", file=sys.stderr)
         # in instruct mode, we inject a prefix and a suffix to each input by the user
         if (self.params.instruct):
             self.params.interactive_start = True
-            self.first_antiprompt.append(self._tokenize(self.params.instruct_inp_prefix.strip(), False))
+            _ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
+            self.first_antiprompt.append(_ptn)
+            self.antiecho = IterSearch(_ptn)
 
         # enable interactive mode if reverse prompt or interactive start is specified
         if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
@@ -217,7 +239,9 @@ n_keep = {self.params.n_keep}
 
             if len(self.embd_inp) <= self.input_consumed:
                 # out of user input, sample next token
-                #TODO: self.params.ignore_eos
+                if (self.params.ignore_eos):
+                    logits = llama_cpp.llama_get_logits(self.ctx)
+                    logits[llama_cpp.llama_token_eos()] = llama_cpp.c_float(0)
 
                 _arr = self.last_n_tokens[-min(self.params.repeat_last_n, self.n_past):]
                 id = llama_cpp.llama_sample_top_p_top_k(
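With ignore_eos set, the end-of-sequence logit is clamped to 0 before sampling. A self-contained sketch (plain Python, no llama_cpp) of what that does to the sampled distribution: a 0 logit only de-emphasises EOS relative to larger logits, while float("-inf") would mask it outright.

import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

EOS = 2                        # hypothetical token id for this sketch
logits = [1.5, 0.3, 4.0, 2.1]
print(softmax(logits)[EOS])    # dominant: EOS is the likeliest token

logits[EOS] = 0.0              # the clamp applied above
print(softmax(logits)[EOS])    # much smaller, but still non-zero

logits[EOS] = float("-inf")    # hard mask, for comparison
print(softmax(logits)[EOS])    # exactly 0.0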
@@ -263,7 +287,11 @@ n_keep = {self.params.n_keep}
             # display tokens
             if self.output_echo:
                 for id in self.embd:
-                    yield id
+                    if self.params.instruct:
+                        for r in self.antiecho(id):
+                            yield r
+                    else:
+                        yield id
 
             # reset color to default if there is no pending user input
             if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
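One consequence of routing instruct-mode output through antiecho: tokens it is holding mid-match are only released on a later mismatch, so a stream that ends while a partial match is pending would silently drop them. A hypothetical drain helper (not part of this commit) that could run after the generation loop:

def drain(search):
    # release whatever IterSearch is still holding back
    _tmp = search.buffer[:]
    search.buffer.clear()
    return _tmp

# e.g. once generation is finished:
#   for r in drain(self.antiecho):
#       yield r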
@@ -279,7 +307,7 @@ n_keep = {self.params.n_keep}
                         break
 
                 # if we are using instruction mode, and we have processed the initial prompt
-                if (self.n_past > 0 and self.params.interactive_start):
+                if (self.params.interactive_start):
                     break
 
             # end of text token