Added instruction mode, fixed infinite generation, and various other fixes
This commit is contained in:
parent
0b32bb3d43
commit
da5a6a7089
1 changed files with 44 additions and 18 deletions
|
@ -8,7 +8,9 @@ Quirks:
|
||||||
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
|
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
|
||||||
* It's always in interactive mode, generation ends either by reaching an antiprompt
|
* It's always in interactive mode, generation ends either by reaching an antiprompt
|
||||||
or running out of n_predict.
|
or running out of n_predict.
|
||||||
* Instruction mode adds its own antiprompt
|
* Instruction mode adds its own antiprompt.
|
||||||
|
You should also still be feeding the model with a "primer" prompt that
|
||||||
|
shows it the expected format.
|
||||||
"""
|
"""
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
|
||||||
|
@ -31,6 +33,8 @@ class LLaMAInteract:
|
||||||
top_p: float=1.,
|
top_p: float=1.,
|
||||||
temp: float=1.0,
|
temp: float=1.0,
|
||||||
repeat_penalty: float=1,
|
repeat_penalty: float=1,
|
||||||
|
instruct_inp_prefix: str="\n\n### Instruction:\n\n",
|
||||||
|
instruct_inp_suffix: str="\n\n### Response:\n\n",
|
||||||
) -> None:
|
) -> None:
|
||||||
# input args
|
# input args
|
||||||
self.instruct = instruct
|
self.instruct = instruct
|
||||||
|
@ -66,12 +70,12 @@ class LLaMAInteract:
|
||||||
|
|
||||||
# determine newline token
|
# determine newline token
|
||||||
self.llama_token_newline = self._tokenize("\n", False)
|
self.llama_token_newline = self._tokenize("\n", False)
|
||||||
self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
|
self.inp_prefix = self._tokenize(instruct_inp_prefix)
|
||||||
self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
|
self.inp_suffix = self._tokenize(instruct_inp_suffix, False)
|
||||||
|
|
||||||
# add instruction as antiprompt
|
# add instruction as antiprompt
|
||||||
if (self.instruct):
|
if (self.instruct):
|
||||||
self.first_antiprompt.append(self.inp_prefix)
|
self.first_antiprompt.append(self.inp_prefix.strip())
|
||||||
|
|
||||||
# primer feed
|
# primer feed
|
||||||
if (len(primer) > 0):
|
if (len(primer) > 0):
|
||||||
|
@ -117,10 +121,9 @@ class LLaMAInteract:
|
||||||
|
|
||||||
# insert n_left/2 tokens at the start of embd from last_n_tokens
|
# insert n_left/2 tokens at the start of embd from last_n_tokens
|
||||||
_insert = self.last_n_tokens[
|
_insert = self.last_n_tokens[
|
||||||
-(int(n_left/2) - len(self.embd)):-len(self.embd)
|
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
|
||||||
]
|
]
|
||||||
self.embd[:len(_insert)] = _insert
|
self.embd = _insert + self.embd
|
||||||
#TODO: Still untested
|
|
||||||
|
|
||||||
if (llama_cpp.llama_eval(
|
if (llama_cpp.llama_eval(
|
||||||
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
|
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
|
||||||
|
@ -197,6 +200,12 @@ class LLaMAInteract:
|
||||||
self.embd_inp += self.first_antiprompt[0]
|
self.embd_inp += self.first_antiprompt[0]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, tb):
|
||||||
|
llama_cpp.llama_free(self.ctx)
|
||||||
|
|
||||||
# return past text
|
# return past text
|
||||||
def past(self):
|
def past(self):
|
||||||
for id in self.last_n_tokens[-self.n_past:]:
|
for id in self.last_n_tokens[-self.n_past:]:
|
||||||
|
@ -206,7 +215,7 @@ class LLaMAInteract:
|
||||||
def input(self, prompt: str):
|
def input(self, prompt: str):
|
||||||
if (self.instruct):
|
if (self.instruct):
|
||||||
self.embd_inp += self.inp_prefix
|
self.embd_inp += self.inp_prefix
|
||||||
self.embd_inp += self._tokenize(prompt + "\n")
|
self.embd_inp += self._tokenize(prompt)
|
||||||
if (self.instruct):
|
if (self.instruct):
|
||||||
self.embd_inp += self.inp_suffix
|
self.embd_inp += self.inp_suffix
|
||||||
|
|
||||||
|
@ -242,21 +251,38 @@ The transcript only includes text, it does not include markup like HTML and Mark
|
||||||
{USER_NAME}:"""
|
{USER_NAME}:"""
|
||||||
|
|
||||||
print("Loading model...")
|
print("Loading model...")
|
||||||
m = LLaMAInteract(prompt,
|
with LLaMAInteract(prompt,
|
||||||
model="./models/30B/ggml-model-q4_0.bin",
|
model="./models/30B/ggml-model-q4_0.bin",
|
||||||
n_ctx=2048,
|
n_ctx=2048,
|
||||||
antiprompt=[f"\n{USER_NAME}:"],
|
antiprompt=[f"\n{USER_NAME}:"],
|
||||||
repeat_last_n=256,
|
repeat_last_n=256,
|
||||||
n_predict=2048,
|
n_predict=2048,
|
||||||
temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
|
temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
|
||||||
)
|
) as m:
|
||||||
print("Loaded model!")
|
print("Loaded model!")
|
||||||
|
|
||||||
for i in m.output():
|
|
||||||
print(i,end="",flush=True)
|
|
||||||
m.input_echo = False
|
|
||||||
|
|
||||||
while True:
|
|
||||||
m.input(" " + input('\n> ' if m.instruct else " "))
|
|
||||||
for i in m.output():
|
for i in m.output():
|
||||||
print(i,end="",flush=True)
|
print(i,end="",flush=True)
|
||||||
|
m.input_echo = False
|
||||||
|
|
||||||
|
def inp():
|
||||||
|
out = ""
|
||||||
|
while (t := input()).endswith("\\"):
|
||||||
|
out += t[:-1] + "\n"
|
||||||
|
return out + t + "\n"
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if (m.instruct):
|
||||||
|
print('\n> ', end="")
|
||||||
|
m.input(inp())
|
||||||
|
else:
|
||||||
|
print(f" ", end="")
|
||||||
|
m.input(f" {inp()}{AI_NAME}:")
|
||||||
|
print(f"{AI_NAME}: ",end="")
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i in m.output():
|
||||||
|
print(i,end="",flush=True)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"\n{USER_NAME}:",end="")
|
||||||
|
m.input(f"\n{USER_NAME}:")
|
||||||
|
|
Loading…
Add table
Reference in a new issue