Added instruction mode, fixed infinite generation, and various other fixes

This commit is contained in:
Mug 2023-04-04 16:18:26 +02:00
parent 0b32bb3d43
commit da5a6a7089

View file

@ -8,7 +8,9 @@ Quirks:
* n_predict can be set to -1 for unlimited length responses (or just a really high value) * n_predict can be set to -1 for unlimited length responses (or just a really high value)
* It's always in interactive mode, generation ends either by reaching an antiprompt * It's always in interactive mode, generation ends either by reaching an antiprompt
or running out of n_predict. or running out of n_predict.
* Instruction mode adds its own antiprompt * Instruction mode adds its own antiprompt.
You should also still be feeding the model with a "primer" prompt that
shows it the expected format.
""" """
import llama_cpp import llama_cpp
@ -31,6 +33,8 @@ class LLaMAInteract:
top_p: float=1., top_p: float=1.,
temp: float=1.0, temp: float=1.0,
repeat_penalty: float=1, repeat_penalty: float=1,
instruct_inp_prefix: str="\n\n### Instruction:\n\n",
instruct_inp_suffix: str="\n\n### Response:\n\n",
) -> None: ) -> None:
# input args # input args
self.instruct = instruct self.instruct = instruct
@ -66,12 +70,12 @@ class LLaMAInteract:
# determine newline token # determine newline token
self.llama_token_newline = self._tokenize("\n", False) self.llama_token_newline = self._tokenize("\n", False)
self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n") self.inp_prefix = self._tokenize(instruct_inp_prefix)
self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False) self.inp_suffix = self._tokenize(instruct_inp_suffix, False)
# add instruction as antiprompt # add instruction as antiprompt
if (self.instruct): if (self.instruct):
self.first_antiprompt.append(self.inp_prefix) self.first_antiprompt.append(self.inp_prefix.strip())
# primer feed # primer feed
if (len(primer) > 0): if (len(primer) > 0):
@ -117,10 +121,9 @@ class LLaMAInteract:
# insert n_left/2 tokens at the start of embd from last_n_tokens # insert n_left/2 tokens at the start of embd from last_n_tokens
_insert = self.last_n_tokens[ _insert = self.last_n_tokens[
-(int(n_left/2) - len(self.embd)):-len(self.embd) self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
] ]
self.embd[:len(_insert)] = _insert self.embd = _insert + self.embd
#TODO: Still untested
if (llama_cpp.llama_eval( if (llama_cpp.llama_eval(
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
@ -197,6 +200,12 @@ class LLaMAInteract:
self.embd_inp += self.first_antiprompt[0] self.embd_inp += self.first_antiprompt[0]
break break
def __enter__(self):
return self
def __exit__(self, type, value, tb):
llama_cpp.llama_free(self.ctx)
# return past text # return past text
def past(self): def past(self):
for id in self.last_n_tokens[-self.n_past:]: for id in self.last_n_tokens[-self.n_past:]:
@ -206,7 +215,7 @@ class LLaMAInteract:
def input(self, prompt: str): def input(self, prompt: str):
if (self.instruct): if (self.instruct):
self.embd_inp += self.inp_prefix self.embd_inp += self.inp_prefix
self.embd_inp += self._tokenize(prompt + "\n") self.embd_inp += self._tokenize(prompt)
if (self.instruct): if (self.instruct):
self.embd_inp += self.inp_suffix self.embd_inp += self.inp_suffix
@ -242,21 +251,38 @@ The transcript only includes text, it does not include markup like HTML and Mark
{USER_NAME}:""" {USER_NAME}:"""
print("Loading model...") print("Loading model...")
m = LLaMAInteract(prompt, with LLaMAInteract(prompt,
model="./models/30B/ggml-model-q4_0.bin", model="./models/30B/ggml-model-q4_0.bin",
n_ctx=2048, n_ctx=2048,
antiprompt=[f"\n{USER_NAME}:"], antiprompt=[f"\n{USER_NAME}:"],
repeat_last_n=256, repeat_last_n=256,
n_predict=2048, n_predict=2048,
temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647 temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
) ) as m:
print("Loaded model!") print("Loaded model!")
for i in m.output(): for i in m.output():
print(i,end="",flush=True) print(i,end="",flush=True)
m.input_echo = False m.input_echo = False
def inp():
out = ""
while (t := input()).endswith("\\"):
out += t[:-1] + "\n"
return out + t + "\n"
while True: while True:
m.input(" " + input('\n> ' if m.instruct else " ")) if (m.instruct):
print('\n> ', end="")
m.input(inp())
else:
print(f" ", end="")
m.input(f" {inp()}{AI_NAME}:")
print(f"{AI_NAME}: ",end="")
try:
for i in m.output(): for i in m.output():
print(i,end="",flush=True) print(i,end="",flush=True)
except KeyboardInterrupt:
print(f"\n{USER_NAME}:",end="")
m.input(f"\n{USER_NAME}:")