Add utf8 to chat example
This commit is contained in:
parent
3ceb47b597
commit
996f63e9e1
3 changed files with 130 additions and 40 deletions
|
@ -102,7 +102,7 @@ def gpt_params_parse(argv = None):
|
||||||
parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)",dest="tfs_z")
|
parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)",dest="tfs_z")
|
||||||
parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)",dest="presence_penalty")
|
parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)",dest="presence_penalty")
|
||||||
parser.add_argument("--mirostat", type=float, default=1.0, help="use Mirostat sampling.",dest="mirostat")
|
parser.add_argument("--mirostat", type=float, default=1.0, help="use Mirostat sampling.",dest="mirostat")
|
||||||
parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau",dest="mirostat_tau")
|
parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau represents the average surprise value",dest="mirostat_tau")
|
||||||
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
|
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
|
||||||
|
|
||||||
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
|
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
|
||||||
|
|
|
@ -17,34 +17,7 @@ from os import cpu_count, path
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
from common import GptParams, gpt_params_parse, gpt_random_prompt
|
from common import GptParams, gpt_params_parse, gpt_random_prompt
|
||||||
|
import util
|
||||||
ANSI_COLOR_RESET = "\x1b[0m"
|
|
||||||
ANSI_COLOR_YELLOW = "\x1b[33m"
|
|
||||||
ANSI_BOLD = "\x1b[1m"
|
|
||||||
ANSI_COLOR_GREEN = "\x1b[32m"
|
|
||||||
|
|
||||||
CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
|
|
||||||
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
|
|
||||||
CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
|
|
||||||
|
|
||||||
# Iterative search
|
|
||||||
# Actively searches and prevents a pattern from being returned
|
|
||||||
class IterSearch:
|
|
||||||
def __init__(self, pattern):
|
|
||||||
self.pattern = list(pattern)
|
|
||||||
self.buffer = []
|
|
||||||
|
|
||||||
def __call__(self, char):
|
|
||||||
self.buffer += [char]
|
|
||||||
|
|
||||||
if (self.pattern[:len(self.buffer)] == self.buffer):
|
|
||||||
if (len(self.buffer) >= len(self.pattern)):
|
|
||||||
self.buffer.clear()
|
|
||||||
return []
|
|
||||||
|
|
||||||
_tmp = self.buffer[:]
|
|
||||||
self.buffer.clear()
|
|
||||||
return _tmp
|
|
||||||
|
|
||||||
# A LLaMA interactive session
|
# A LLaMA interactive session
|
||||||
class LLaMAInteract:
|
class LLaMAInteract:
|
||||||
|
@ -82,6 +55,7 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
self.first_antiprompt = []
|
self.first_antiprompt = []
|
||||||
self.remaining_tokens = self.params.n_predict
|
self.remaining_tokens = self.params.n_predict
|
||||||
self.output_echo = self.params.input_echo
|
self.output_echo = self.params.input_echo
|
||||||
|
self.multibyte_fix = []
|
||||||
|
|
||||||
# model load
|
# model load
|
||||||
self.lparams = llama_cpp.llama_context_default_params()
|
self.lparams = llama_cpp.llama_context_default_params()
|
||||||
|
@ -188,7 +162,7 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
self.params.interactive_start = True
|
self.params.interactive_start = True
|
||||||
_ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
|
_ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
|
||||||
self.first_antiprompt.append(_ptn)
|
self.first_antiprompt.append(_ptn)
|
||||||
self.antiecho = IterSearch(_ptn)
|
self.antiecho = util.IterSearch(_ptn)
|
||||||
|
|
||||||
# enable interactive mode if reverse prompt or interactive start is specified
|
# enable interactive mode if reverse prompt or interactive start is specified
|
||||||
if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
|
if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
|
||||||
|
@ -256,14 +230,14 @@ n_keep = {self.params.n_keep}
|
||||||
- If you want to submit another line, end your input in '\\'.
|
- If you want to submit another line, end your input in '\\'.
|
||||||
|
|
||||||
""", file=sys.stderr)
|
""", file=sys.stderr)
|
||||||
self.set_color(CONSOLE_COLOR_PROMPT)
|
self.set_color(util.CONSOLE_COLOR_PROMPT)
|
||||||
|
|
||||||
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
||||||
|
|
||||||
|
|
||||||
# tokenize a prompt
|
# tokenize a prompt
|
||||||
def _tokenize(self, prompt, bos=True):
|
def _tokenize(self, prompt, bos=True):
|
||||||
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
|
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
|
||||||
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
|
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
|
||||||
return _arr[:_n]
|
return _arr[:_n]
|
||||||
|
|
||||||
|
@ -295,7 +269,6 @@ n_keep = {self.params.n_keep}
|
||||||
self.params.path_session = ""
|
self.params.path_session = ""
|
||||||
|
|
||||||
# try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
# try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
||||||
# REVIEW
|
|
||||||
if self.n_session_consumed < len(self.session_tokens):
|
if self.n_session_consumed < len(self.session_tokens):
|
||||||
for i in range(len(self.embd)):
|
for i in range(len(self.embd)):
|
||||||
if self.embd[i] != self.session_tokens[self.n_session_consumed]:
|
if self.embd[i] != self.session_tokens[self.n_session_consumed]:
|
||||||
|
@ -445,7 +418,7 @@ n_keep = {self.params.n_keep}
|
||||||
|
|
||||||
# reset color to default if we there is no pending user input
|
# reset color to default if we there is no pending user input
|
||||||
if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
|
if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
|
||||||
self.set_color(CONSOLE_COLOR_DEFAULT)
|
self.set_color(util.CONSOLE_COLOR_DEFAULT)
|
||||||
|
|
||||||
if (self.params.interactive and len(self.embd_inp) <= self.input_consumed):
|
if (self.params.interactive and len(self.embd_inp) <= self.input_consumed):
|
||||||
# if antiprompt is present, stop
|
# if antiprompt is present, stop
|
||||||
|
@ -486,12 +459,12 @@ n_keep = {self.params.n_keep}
|
||||||
|
|
||||||
def exit(self):
|
def exit(self):
|
||||||
llama_cpp.llama_free(self.ctx)
|
llama_cpp.llama_free(self.ctx)
|
||||||
self.set_color(CONSOLE_COLOR_DEFAULT)
|
self.set_color(util.CONSOLE_COLOR_DEFAULT)
|
||||||
|
|
||||||
# return past text
|
# return past text
|
||||||
def past(self):
|
def past(self):
|
||||||
for id in self.last_n_tokens[-self.n_past:]:
|
for id in self.last_n_tokens[-self.n_past:]:
|
||||||
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
|
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
|
||||||
|
|
||||||
# write input
|
# write input
|
||||||
def input(self, prompt: str):
|
def input(self, prompt: str):
|
||||||
|
@ -505,7 +478,29 @@ n_keep = {self.params.n_keep}
|
||||||
def output(self):
|
def output(self):
|
||||||
self.remaining_tokens = self.params.n_predict
|
self.remaining_tokens = self.params.n_predict
|
||||||
for id in self.generate():
|
for id in self.generate():
|
||||||
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
|
cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
|
||||||
|
|
||||||
|
# Add remainder of missing bytes
|
||||||
|
if None in self.multibyte_fix:
|
||||||
|
self.multibyte_fix[self.multibyte_fix.index(None)] = cur_char
|
||||||
|
|
||||||
|
# Return completed utf char
|
||||||
|
if len(self.multibyte_fix) > 0 and not None in self.multibyte_fix:
|
||||||
|
yield (b"".join(self.multibyte_fix)).decode("utf8")
|
||||||
|
self.multibyte_fix = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Contains multi-byte UTF8
|
||||||
|
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||||
|
# Bitwise AND check
|
||||||
|
if pattern & int.from_bytes(cur_char) == pattern:
|
||||||
|
self.multibyte_fix = [cur_char] + ([None] * (num-1))
|
||||||
|
|
||||||
|
# Stop incomplete bytes from passing
|
||||||
|
if len(self.multibyte_fix) > 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield cur_char.decode("utf8")
|
||||||
|
|
||||||
# read user input
|
# read user input
|
||||||
def read_input(self):
|
def read_input(self):
|
||||||
|
@ -521,7 +516,7 @@ n_keep = {self.params.n_keep}
|
||||||
self.params.input_echo = False
|
self.params.input_echo = False
|
||||||
|
|
||||||
while self.params.interactive:
|
while self.params.interactive:
|
||||||
self.set_color(CONSOLE_COLOR_USER_INPUT)
|
self.set_color(util.CONSOLE_COLOR_USER_INPUT)
|
||||||
if (self.params.instruct):
|
if (self.params.instruct):
|
||||||
print('\n> ', end="")
|
print('\n> ', end="")
|
||||||
self.input(self.read_input())
|
self.input(self.read_input())
|
||||||
|
@ -529,13 +524,13 @@ n_keep = {self.params.n_keep}
|
||||||
print(self.params.input_prefix, end="")
|
print(self.params.input_prefix, end="")
|
||||||
self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}")
|
self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}")
|
||||||
print(self.params.input_suffix,end="")
|
print(self.params.input_suffix,end="")
|
||||||
self.set_color(CONSOLE_COLOR_DEFAULT)
|
self.set_color(util.CONSOLE_COLOR_DEFAULT)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for i in self.output():
|
for i in self.output():
|
||||||
print(i,end="",flush=True)
|
print(i,end="",flush=True)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.set_color(CONSOLE_COLOR_DEFAULT)
|
self.set_color(util.CONSOLE_COLOR_DEFAULT)
|
||||||
if not self.params.instruct:
|
if not self.params.instruct:
|
||||||
print(self.params.fix_prefix,end="")
|
print(self.params.fix_prefix,end="")
|
||||||
self.input(self.params.fix_prefix)
|
self.input(self.params.fix_prefix)
|
||||||
|
|
95
examples/low_level_api/util.py
Normal file
95
examples/low_level_api/util.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
|
||||||
|
ANSI_COLOR_RESET = "\x1b[0m"
|
||||||
|
ANSI_COLOR_YELLOW = "\x1b[33m"
|
||||||
|
ANSI_BOLD = "\x1b[1m"
|
||||||
|
ANSI_COLOR_GREEN = "\x1b[32m"
|
||||||
|
|
||||||
|
CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
|
||||||
|
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
|
||||||
|
CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
|
||||||
|
|
||||||
|
# Iterative search
|
||||||
|
# Actively searches and prevents a pattern from being returned
|
||||||
|
class IterSearch:
|
||||||
|
def __init__(self, pattern):
|
||||||
|
self.pattern = list(pattern)
|
||||||
|
self.buffer = []
|
||||||
|
|
||||||
|
def __call__(self, char):
|
||||||
|
self.buffer += [char]
|
||||||
|
|
||||||
|
if (self.pattern[:len(self.buffer)] == self.buffer):
|
||||||
|
if (len(self.buffer) >= len(self.pattern)):
|
||||||
|
self.buffer.clear()
|
||||||
|
return []
|
||||||
|
|
||||||
|
_tmp = self.buffer[:]
|
||||||
|
self.buffer.clear()
|
||||||
|
return _tmp
|
||||||
|
|
||||||
|
class Circle:
|
||||||
|
def __init__(self, size, default=0):
|
||||||
|
self.list = [default] * size
|
||||||
|
self.maxsize = size
|
||||||
|
self.size = 0
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def append(self, elem):
|
||||||
|
if self.size < self.maxsize:
|
||||||
|
self.list[self.size] = elem
|
||||||
|
self.size += 1
|
||||||
|
else:
|
||||||
|
self.list[self.offset] = elem
|
||||||
|
self.offset = (self.offset + 1) % self.maxsize
|
||||||
|
|
||||||
|
def __getitem__(self, val):
|
||||||
|
if isinstance(val, int):
|
||||||
|
if 0 > val or val >= self.size:
|
||||||
|
raise IndexError('Index out of range')
|
||||||
|
return self.list[val] if self.size < self.maxsize else self.list[(self.offset + val) % self.maxsize]
|
||||||
|
elif isinstance(val, slice):
|
||||||
|
start, stop, step = val.start, val.stop, val.step
|
||||||
|
if step is None:
|
||||||
|
step = 1
|
||||||
|
if start is None:
|
||||||
|
start = 0
|
||||||
|
if stop is None:
|
||||||
|
stop = self.size
|
||||||
|
if start < 0:
|
||||||
|
start = self.size + start
|
||||||
|
if stop < 0:
|
||||||
|
stop = self.size + stop
|
||||||
|
|
||||||
|
indices = range(start, stop, step)
|
||||||
|
return [self.list[(self.offset + i) % self.maxsize] for i in indices if i < self.size]
|
||||||
|
else:
|
||||||
|
raise TypeError('Invalid argument type')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
c = Circle(5)
|
||||||
|
|
||||||
|
c.append(1)
|
||||||
|
print(c.list)
|
||||||
|
print(c[:])
|
||||||
|
assert c[0] == 1
|
||||||
|
assert c[:5] == [1]
|
||||||
|
|
||||||
|
for i in range(2,5+1):
|
||||||
|
c.append(i)
|
||||||
|
print(c.list)
|
||||||
|
print(c[:])
|
||||||
|
assert c[0] == 1
|
||||||
|
assert c[:5] == [1,2,3,4,5]
|
||||||
|
|
||||||
|
for i in range(5+1,9+1):
|
||||||
|
c.append(i)
|
||||||
|
print(c.list)
|
||||||
|
print(c[:])
|
||||||
|
assert c[0] == 5
|
||||||
|
assert c[:5] == [5,6,7,8,9]
|
||||||
|
#assert c[:-5] == [5,6,7,8,9]
|
||||||
|
assert c[:10] == [5,6,7,8,9]
|
||||||
|
|
Loading…
Reference in a new issue