Add instruction mode

This commit is contained in:
Mug 2023-04-04 11:48:48 +02:00
parent f1615f05e6
commit 0b32bb3d43

View file

@ -5,24 +5,26 @@ Quirks:
* Input is always echoed if on, so it should be turned off when using "input()" * Input is always echoed if on, so it should be turned off when using "input()"
* The first antiprompt should be the userprompt like "\nUser:", * The first antiprompt should be the userprompt like "\nUser:",
because its added when n_predict is reached (aka generation ended prematurely) because its added when n_predict is reached (aka generation ended prematurely)
* n_predict can be set to -1 for unlimited length responses * n_predict can be set to -1 for unlimited length responses (or just a really high value)
* It's always in interactive mode, generation ends either by reaching an antiprompt
or running out of n_predict.
* Instruction mode adds its own antiprompt
""" """
import llama_cpp import llama_cpp
def toIntArray(lst):
return [int(i) for i in lst]
# A LLaMA interactive session # A LLaMA interactive session
class LLaMAInteract: class LLaMAInteract:
def __init__(self, def __init__(self,
primer: str="", primer: str="",
model: str="./models/30B/ggml-model-q4_0.bin", model: str="./models/30B/ggml-model-q4_0.bin",
instruct: bool=False,
n_ctx: int=1024, n_ctx: int=1024,
seed: int=0, seed: int=0,
n_threads: int=8, n_threads: int=8,
antiprompt: list[str]=[], antiprompt: list[str]=[],
input_echo: bool=True, input_echo: bool=True,
n_predict: int=20, n_predict: int=20,
n_keep: int=0,
n_batch: int=8, n_batch: int=8,
repeat_last_n: int=64, repeat_last_n: int=64,
top_k: int=50, top_k: int=50,
@ -31,17 +33,17 @@ class LLaMAInteract:
repeat_penalty: float=1, repeat_penalty: float=1,
) -> None: ) -> None:
# input args # input args
self.instruct = instruct
self.n_threads = n_threads self.n_threads = n_threads
self.input_echo = input_echo self.input_echo = input_echo
self.n_predict = n_predict self.n_predict = n_predict
self.n_keep = n_keep
self.n_batch = n_batch self.n_batch = n_batch
self.repeat_last_n = repeat_last_n self.repeat_last_n = repeat_last_n
self.top_k=top_k self.top_k=top_k
self.top_p=top_p self.top_p=top_p
self.temp=temp self.temp=temp
self.repeat_penalty=repeat_penalty self.repeat_penalty=repeat_penalty
self.n_ctx = n_ctx
self.seed = seed
# runtime args # runtime args
self.input_consumed = 0 self.input_consumed = 0
@ -54,8 +56,8 @@ class LLaMAInteract:
# model load # model load
self.lparams = llama_cpp.llama_context_default_params() self.lparams = llama_cpp.llama_context_default_params()
self.lparams.n_ctx = self.n_ctx self.lparams.n_ctx = n_ctx
self.lparams.seed = self.seed self.lparams.seed = seed
self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams) self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
# determine the required inference memory per token: # determine the required inference memory per token:
@ -63,29 +65,44 @@ class LLaMAInteract:
llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads) llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
# determine newline token # determine newline token
self.llama_token_newline = (llama_cpp.llama_token * 1)() self.llama_token_newline = self._tokenize("\n", False)
llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False) self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
self.llama_token_newline = toIntArray(self.llama_token_newline) self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
# add instruction as antiprompt
if (self.instruct):
self.first_antiprompt.append(self.inp_prefix)
# primer feed # primer feed
if (len(primer) > 0): if (len(primer) > 0):
self.input(primer) self.embd_inp += self._tokenize(primer)
# break immediately if using instruct
self.init_break = self.instruct
# number of tokens to keep when resetting context
if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
self.n_keep = len(self.embd_inp) self.n_keep = len(self.embd_inp)
# create internal context # create internal context
self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
# determine antiprompt tokens # determine antiprompt tokens
for i in antiprompt: for i in antiprompt:
d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))() self.first_antiprompt.append(self._tokenize(i, False))
n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt])) # tokenize a prompt
def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
return _arr[:_n]
# if an antiprompt is present # if an antiprompt is present
def use_antiprompt(self): def use_antiprompt(self):
return len(self.first_antiprompt) > 0 return len(self.first_antiprompt) > 0
# generate tokens
def generate(self): def generate(self):
while self.remaining_tokens > 0 or self.use_antiprompt(): while self.remaining_tokens > 0 or self.use_antiprompt():
# predict # predict
@ -125,16 +142,16 @@ class LLaMAInteract:
self.repeat_penalty, self.repeat_penalty,
) )
self.last_n_tokens.pop(0) self.last_n_tokens.pop(0)
self.last_n_tokens.append(int(id)) self.last_n_tokens.append(id)
# replace end of text token with newline token when in interactive mode # replace end of text token with newline token when in interactive mode
if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()): if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
id = self.llama_token_newline[0] id = self.llama_token_newline[0]
# tokenize and inject first reverse prompt # tokenize and inject first reverse prompt
self.embd_inp += self.first_antiprompt[0] self.embd_inp += self.first_antiprompt[0]
# add it to the context # add it to the context
self.embd.append(int(id)) self.embd.append(id)
# echo this to console # echo this to console
self.output_echo = True self.output_echo = True
@ -147,9 +164,9 @@ class LLaMAInteract:
# some user input remains from prompt or interaction, forward it to processing # some user input remains from prompt or interaction, forward it to processing
while len(self.embd_inp) > self.input_consumed: while len(self.embd_inp) > self.input_consumed:
self.embd.append(int(self.embd_inp[self.input_consumed])) self.embd.append(self.embd_inp[self.input_consumed])
self.last_n_tokens.pop(0) self.last_n_tokens.pop(0)
self.last_n_tokens.append(int(self.embd_inp[self.input_consumed])) self.last_n_tokens.append(self.embd_inp[self.input_consumed])
self.input_consumed += 1 self.input_consumed += 1
if len(self.embd) >= self.n_batch: if len(self.embd) >= self.n_batch:
break break
@ -159,12 +176,18 @@ class LLaMAInteract:
for id in self.embd: for id in self.embd:
yield id yield id
if (len(self.embd_inp) <= self.input_consumed):
# if antiprompt is present, stop # if antiprompt is present, stop
if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed): if (self.use_antiprompt()):
for i in self.first_antiprompt: for i in self.first_antiprompt:
if i == self.last_n_tokens[-len(i):]: if i == self.last_n_tokens[-len(i):]:
return return
# if we are using instruction mode, and we have processed the initial prompt
if (self.init_break):
self.init_break = False
break
# if end of generation # if end of generation
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(): if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
break break
@ -174,15 +197,20 @@ class LLaMAInteract:
self.embd_inp += self.first_antiprompt[0] self.embd_inp += self.first_antiprompt[0]
break break
# return past text
def past(self): def past(self):
for id in self.last_n_tokens[-self.n_past:]: for id in self.last_n_tokens[-self.n_past:]:
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8") yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
# write input
def input(self, prompt: str): def input(self, prompt: str):
embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))() if (self.instruct):
n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True) self.embd_inp += self.inp_prefix
self.embd_inp += toIntArray(embd_arr[:n_of_tok]) self.embd_inp += self._tokenize(prompt + "\n")
if (self.instruct):
self.embd_inp += self.inp_suffix
# write output
def output(self): def output(self):
self.remaining_tokens = self.n_predict self.remaining_tokens = self.n_predict
for id in self.generate(): for id in self.generate():
@ -214,7 +242,7 @@ The transcript only includes text, it does not include markup like HTML and Mark
{USER_NAME}:""" {USER_NAME}:"""
print("Loading model...") print("Loading model...")
ll = LLaMAInteract(prompt, m = LLaMAInteract(prompt,
model="./models/30B/ggml-model-q4_0.bin", model="./models/30B/ggml-model-q4_0.bin",
n_ctx=2048, n_ctx=2048,
antiprompt=[f"\n{USER_NAME}:"], antiprompt=[f"\n{USER_NAME}:"],
@ -224,12 +252,11 @@ The transcript only includes text, it does not include markup like HTML and Mark
) )
print("Loaded model!") print("Loaded model!")
for i in ll.output(): for i in m.output():
print(i,end="",flush=True) print(i,end="",flush=True)
ll.input_echo = False m.input_echo = False
inp = lambda x: f" {x}\n"
while True: while True:
ll.input(inp(input(' '))) m.input(" " + input('\n> ' if m.instruct else " "))
for i in ll.output(): for i in m.output():
print(i,end="",flush=True) print(i,end="",flush=True)