Add instruction mode
This commit is contained in:
parent f1615f05e6
commit 0b32bb3d43
1 changed file with 64 additions and 37 deletions
@@ -5,24 +5,26 @@ Quirks:
  * Input is always echoed if on, so it should be turned off when using "input()"
  * The first antiprompt should be the userprompt like "\nUser:",
    because it's added when n_predict is reached (aka generation ended prematurely)
- * n_predict can be set to -1 for unlimited length responses
+ * n_predict can be set to -1 for unlimited length responses (or just a really high value)
+ * It's always in interactive mode, generation ends either by reaching an antiprompt
+   or running out of n_predict.
+ * Instruction mode adds its own antiprompt
 """
 import llama_cpp

-def toIntArray(lst):
-    return [int(i) for i in lst]
-
 # A LLaMA interactive session
 class LLaMAInteract:
     def __init__(self,
         primer: str="",
         model: str="./models/30B/ggml-model-q4_0.bin",
+        instruct: bool=False,
         n_ctx: int=1024,
         seed: int=0,
         n_threads: int=8,
         antiprompt: list[str]=[],
         input_echo: bool=True,
         n_predict: int=20,
+        n_keep: int=0,
         n_batch: int=8,
         repeat_last_n: int=64,
         top_k: int=50,
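The hunk above adds the two new constructor arguments, instruct and n_keep. As a rough illustration only (the class must already be defined in this file; the primer text and values here are made-up placeholders, not part of the commit), an instruction-mode session could be started like this:

    # hypothetical instantiation; instruct=True turns on Alpaca-style instruction mode
    m = LLaMAInteract(
        "Below is an instruction that describes a task.",  # primer (placeholder text)
        model="./models/30B/ggml-model-q4_0.bin",
        instruct=True,
        n_keep=0,      # overridden to the full prompt length because instruct=True (see __init__ below)
        n_predict=256,
    )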
@@ -31,17 +33,17 @@ class LLaMAInteract:
         repeat_penalty: float=1,
     ) -> None:
         # input args
+        self.instruct = instruct
         self.n_threads = n_threads
         self.input_echo = input_echo
         self.n_predict = n_predict
+        self.n_keep = n_keep
         self.n_batch = n_batch
         self.repeat_last_n = repeat_last_n
         self.top_k=top_k
         self.top_p=top_p
         self.temp=temp
         self.repeat_penalty=repeat_penalty
-        self.n_ctx = n_ctx
-        self.seed = seed

         # runtime args
         self.input_consumed = 0
@@ -54,8 +56,8 @@ class LLaMAInteract:

         # model load
         self.lparams = llama_cpp.llama_context_default_params()
-        self.lparams.n_ctx = self.n_ctx
-        self.lparams.seed = self.seed
+        self.lparams.n_ctx = n_ctx
+        self.lparams.seed = seed
         self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)

         # determine the required inference memory per token:
@@ -63,29 +65,44 @@ class LLaMAInteract:
         llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)

         # determine newline token
-        self.llama_token_newline = (llama_cpp.llama_token * 1)()
-        llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
-        self.llama_token_newline = toIntArray(self.llama_token_newline)
+        self.llama_token_newline = self._tokenize("\n", False)
+        self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
+        self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)

+        # add instruction as antiprompt
+        if (self.instruct):
+            self.first_antiprompt.append(self.inp_prefix)
+
         # primer feed
         if (len(primer) > 0):
-            self.input(primer)
+            self.embd_inp += self._tokenize(primer)

-        self.n_keep = len(self.embd_inp)
+        # break immediately if using instruct
+        self.init_break = self.instruct
+
+        # number of tokens to keep when resetting context
+        if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
+            self.n_keep = len(self.embd_inp)

         # create internal context
-        self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
         self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices

         # determine antiprompt tokens
         for i in antiprompt:
-            d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
-            n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
-            self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))
+            self.first_antiprompt.append(self._tokenize(i, False))
+
+    # tokenize a prompt
+    def _tokenize(self, prompt, bos=True):
+        _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        return _arr[:_n]

     # if an antiprompt is present
     def use_antiprompt(self):
         return len(self.first_antiprompt) > 0

+    # generate tokens
     def generate(self):
         while self.remaining_tokens > 0 or self.use_antiprompt():
             # predict
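The new _tokenize helper also explains why toIntArray could be dropped: slicing a ctypes array already yields a plain Python list of ints, so no explicit conversion is needed before concatenating onto embd_inp. A standalone sketch of that behavior (independent of the commit):

    import ctypes

    arr = (ctypes.c_int * 4)(7, 8, 9, 10)
    print(arr[:2])        # [7, 8] -- a real Python list of ints
    print(arr[:2] + [1])  # list concatenation works, so `self.embd_inp += ...` is fine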
@@ -125,16 +142,16 @@ class LLaMAInteract:
                 self.repeat_penalty,
             )
             self.last_n_tokens.pop(0)
-            self.last_n_tokens.append(int(id))
+            self.last_n_tokens.append(id)

             # replace end of text token with newline token when in interactive mode
-            if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
+            if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
                 id = self.llama_token_newline[0]
                 # tokenize and inject first reverse prompt
                 self.embd_inp += self.first_antiprompt[0]

             # add it to the context
-            self.embd.append(int(id))
+            self.embd.append(id)

             # echo this to console
             self.output_echo = True
@@ -147,9 +164,9 @@ class LLaMAInteract:

             # some user input remains from prompt or interaction, forward it to processing
             while len(self.embd_inp) > self.input_consumed:
-                self.embd.append(int(self.embd_inp[self.input_consumed]))
+                self.embd.append(self.embd_inp[self.input_consumed])
                 self.last_n_tokens.pop(0)
-                self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
+                self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                 self.input_consumed += 1
                 if len(self.embd) >= self.n_batch:
                     break
@@ -159,12 +176,18 @@ class LLaMAInteract:
             for id in self.embd:
                 yield id

-            # if antiprompt is present, stop
-            if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
-                for i in self.first_antiprompt:
-                    if i == self.last_n_tokens[-len(i):]:
-                        return
+            if (len(self.embd_inp) <= self.input_consumed):
+                # if antiprompt is present, stop
+                if (self.use_antiprompt()):
+                    for i in self.first_antiprompt:
+                        if i == self.last_n_tokens[-len(i):]:
+                            return
+
+                # if we are using instruction mode, and we have processed the initial prompt
+                if (self.init_break):
+                    self.init_break = False
+                    break

             # if end of generation
             if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                 break
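The antiprompt check in this hunk is a plain list-suffix comparison: each tokenized antiprompt is tested against the tail of last_n_tokens. A self-contained sketch of the same idea (toy token ids, not real vocabulary):

    last_n_tokens = [5, 1, 99, 42, 17]
    antiprompt = [42, 17]  # hypothetical token ids for something like "\nUser:"

    # True when the most recently generated tokens end with the antiprompt
    print(antiprompt == last_n_tokens[-len(antiprompt):])  # True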
@@ -174,15 +197,20 @@ class LLaMAInteract:
                     self.embd_inp += self.first_antiprompt[0]
                     break

+    # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
             yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")

+    # write input
     def input(self, prompt: str):
-        embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
-        self.embd_inp += toIntArray(embd_arr[:n_of_tok])
+        if (self.instruct):
+            self.embd_inp += self.inp_prefix
+        self.embd_inp += self._tokenize(prompt + "\n")
+        if (self.instruct):
+            self.embd_inp += self.inp_suffix

+    # write output
     def output(self):
         self.remaining_tokens = self.n_predict
         for id in self.generate():
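In instruction mode, input() now brackets the user's text with the Alpaca-style prefix and suffix tokenized in __init__, so the model effectively sees text of this shape (a string-level sketch with a made-up prompt; the class itself works on token lists):

    prefix = "\n\n### Instruction:\n\n"
    suffix = "\n\n### Response:\n\n"
    prompt = "Explain quicksort in one sentence."  # placeholder user input

    # token-level equivalent: inp_prefix + _tokenize(prompt + "\n") + inp_suffix
    print(prefix + prompt + "\n" + suffix)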
@@ -214,7 +242,7 @@ The transcript only includes text, it does not include markup like HTML and Mark
 {USER_NAME}:"""

 print("Loading model...")
-ll = LLaMAInteract(prompt,
+m = LLaMAInteract(prompt,
     model="./models/30B/ggml-model-q4_0.bin",
     n_ctx=2048,
     antiprompt=[f"\n{USER_NAME}:"],
@@ -224,12 +252,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
 )
 print("Loaded model!")

-for i in ll.output():
+for i in m.output():
     print(i,end="",flush=True)
-ll.input_echo = False
+m.input_echo = False

-inp = lambda x: f" {x}\n"
 while True:
-    ll.input(inp(input(' ')))
-    for i in ll.output():
+    m.input(" " + input('\n> ' if m.instruct else " "))
+    for i in m.output():
         print(i,end="",flush=True)
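To try the new mode with this driver, only the constructor call needs to change; a hedged variant reusing the example's own placeholders (instruction mode supplies its own antiprompt, so the chat-style antiprompt list can be dropped):

    m = LLaMAInteract(prompt,
        model="./models/30B/ggml-model-q4_0.bin",
        n_ctx=2048,
        instruct=True,  # enables the '\n> ' input prompt and Alpaca-style wrapping
    )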