Added instruction mode, fixed infinite generation, and various other fixes
This commit is contained in:
parent
0b32bb3d43
commit
da5a6a7089
1 changed files with 44 additions and 18 deletions
|
@ -8,7 +8,9 @@ Quirks:
|
|||
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
|
||||
* It's always in interactive mode, generation ends either by reaching an antiprompt
|
||||
or running out of n_predict.
|
||||
* Instruction mode adds its own antiprompt
|
||||
* Instruction mode adds its own antiprompt.
|
||||
You should also still be feeding the model with a "primer" prompt that
|
||||
shows it the expected format.
|
||||
"""
|
||||
import llama_cpp
|
||||
|
||||
|
@ -31,6 +33,8 @@ class LLaMAInteract:
|
|||
top_p: float=1.,
|
||||
temp: float=1.0,
|
||||
repeat_penalty: float=1,
|
||||
instruct_inp_prefix: str="\n\n### Instruction:\n\n",
|
||||
instruct_inp_suffix: str="\n\n### Response:\n\n",
|
||||
) -> None:
|
||||
# input args
|
||||
self.instruct = instruct
|
||||
|
@ -66,12 +70,12 @@ class LLaMAInteract:
|
|||
|
||||
# determine newline token
|
||||
self.llama_token_newline = self._tokenize("\n", False)
|
||||
self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
|
||||
self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
|
||||
self.inp_prefix = self._tokenize(instruct_inp_prefix)
|
||||
self.inp_suffix = self._tokenize(instruct_inp_suffix, False)
|
||||
|
||||
# add instruction as antiprompt
|
||||
if (self.instruct):
|
||||
self.first_antiprompt.append(self.inp_prefix)
|
||||
self.first_antiprompt.append(self.inp_prefix.strip())
|
||||
|
||||
# primer feed
|
||||
if (len(primer) > 0):
|
||||
|
@ -117,10 +121,9 @@ class LLaMAInteract:
|
|||
|
||||
# insert n_left/2 tokens at the start of embd from last_n_tokens
|
||||
_insert = self.last_n_tokens[
|
||||
-(int(n_left/2) - len(self.embd)):-len(self.embd)
|
||||
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
|
||||
]
|
||||
self.embd[:len(_insert)] = _insert
|
||||
#TODO: Still untested
|
||||
self.embd = _insert + self.embd
|
||||
|
||||
if (llama_cpp.llama_eval(
|
||||
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
|
||||
|
@ -197,6 +200,12 @@ class LLaMAInteract:
|
|||
self.embd_inp += self.first_antiprompt[0]
|
||||
break
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
llama_cpp.llama_free(self.ctx)
|
||||
|
||||
# return past text
|
||||
def past(self):
|
||||
for id in self.last_n_tokens[-self.n_past:]:
|
||||
|
@ -206,7 +215,7 @@ class LLaMAInteract:
|
|||
def input(self, prompt: str):
|
||||
if (self.instruct):
|
||||
self.embd_inp += self.inp_prefix
|
||||
self.embd_inp += self._tokenize(prompt + "\n")
|
||||
self.embd_inp += self._tokenize(prompt)
|
||||
if (self.instruct):
|
||||
self.embd_inp += self.inp_suffix
|
||||
|
||||
|
@ -242,21 +251,38 @@ The transcript only includes text, it does not include markup like HTML and Mark
|
|||
{USER_NAME}:"""
|
||||
|
||||
print("Loading model...")
|
||||
m = LLaMAInteract(prompt,
|
||||
with LLaMAInteract(prompt,
|
||||
model="./models/30B/ggml-model-q4_0.bin",
|
||||
n_ctx=2048,
|
||||
antiprompt=[f"\n{USER_NAME}:"],
|
||||
repeat_last_n=256,
|
||||
n_predict=2048,
|
||||
temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
|
||||
)
|
||||
print("Loaded model!")
|
||||
) as m:
|
||||
print("Loaded model!")
|
||||
|
||||
for i in m.output():
|
||||
print(i,end="",flush=True)
|
||||
m.input_echo = False
|
||||
|
||||
while True:
|
||||
m.input(" " + input('\n> ' if m.instruct else " "))
|
||||
for i in m.output():
|
||||
print(i,end="",flush=True)
|
||||
print(i,end="",flush=True)
|
||||
m.input_echo = False
|
||||
|
||||
def inp():
|
||||
out = ""
|
||||
while (t := input()).endswith("\\"):
|
||||
out += t[:-1] + "\n"
|
||||
return out + t + "\n"
|
||||
|
||||
while True:
|
||||
if (m.instruct):
|
||||
print('\n> ', end="")
|
||||
m.input(inp())
|
||||
else:
|
||||
print(f" ", end="")
|
||||
m.input(f" {inp()}{AI_NAME}:")
|
||||
print(f"{AI_NAME}: ",end="")
|
||||
|
||||
try:
|
||||
for i in m.output():
|
||||
print(i,end="",flush=True)
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n{USER_NAME}:",end="")
|
||||
m.input(f"\n{USER_NAME}:")
|
||||
|
|
Loading…
Reference in a new issue