Better llama.cpp interoperability

Still has some issues with too many newlines, so this is WIP.
Mug 2023-04-06 15:30:57 +02:00
parent 283e59c5e9
commit 085cc92b1f
4 changed files with 357 additions and 120 deletions

examples/__init__.py (new file, 0 lines)

examples/common.py (new file, 135 lines)
@@ -0,0 +1,135 @@
import os
import argparse
from dataclasses import dataclass, field
from typing import List, Optional
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
@dataclass
class GptParams:
    seed: int = -1
    n_threads: int = min(4, os.cpu_count() or 1)
    n_predict: int = 128
    repeat_last_n: int = 64
    n_parts: int = -1
    n_ctx: int = 512
    n_batch: int = 8
    n_keep: int = 0

    top_k: int = 40
    top_p: float = 0.95
    temp: float = 0.80
    repeat_penalty: float = 1.10

    model: str = "./models/llama-7B/ggml-model.bin"
    prompt: str = ""
    input_prefix: str = " "
    fix_prefix: str = ""
    output_postfix: str = ""
    input_echo: bool = True

    antiprompt: List[str] = field(default_factory=list)

    memory_f16: bool = True
    random_prompt: bool = False
    use_color: bool = False
    interactive: bool = False

    embedding: bool = False
    interactive_start: bool = False

    instruct: bool = False
    ignore_eos: bool = False
    perplexity: bool = False
    use_mlock: bool = False
    mem_test: bool = False
    verbose_prompt: bool = False

    # Default instructions for Alpaca
    # switch to "Human" and "Assistant" for Vicuna.
    instruct_inp_prefix: str = "\n\n### Instruction:\n\n"
    instruct_inp_suffix: str = "\n\n### Response:\n\n"
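
# Parse llama.cpp-style command line flags into an argparse Namespace; most options map
# straight onto GptParams fields via their dest names.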
def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
    if params is None:
        params = GptParams()

    # add_help=False so the explicit "-h/--help" flag below (kept for parity with llama.cpp) does not clash with argparse's built-in one
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-h", "--help", action="store_true", help="show this help message and exit")
parser.add_argument("-s", "--seed", type=int, default=-1, help="",dest="seed")
parser.add_argument("-t", "--threads", type=int, default=1, help="",dest="n_threads")
parser.add_argument("-p", "--prompt", type=str, default="", help="",dest="prompt")
parser.add_argument("-f", "--file", type=str, default=None, help="")
parser.add_argument("-c", "--ctx_size", type=int, default=512, help="",dest="n_ctx")
parser.add_argument("--memory_f32", action="store_false", help="",dest="memory_f16")
parser.add_argument("--top_p", type=float, default=0.9, help="",dest="top_p")
parser.add_argument("--temp", type=float, default=1.0, help="",dest="temp")
parser.add_argument("--repeat_last_n", type=int, default=64, help="",dest="repeat_last_n")
parser.add_argument("--repeat_penalty", type=float, default=1.0, help="",dest="repeat_penalty")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="",dest="n_batch")
parser.add_argument("--keep", type=int, default=0, help="",dest="n_keep")
parser.add_argument("-m", "--model", type=str, help="",dest="model")
parser.add_argument(
"-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
)
parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
parser.add_argument("--interactive-start", action="store_true", help="", dest="interactive_start")
parser.add_argument(
"--interactive-first",
action="store_true",
help="run in interactive mode and wait for input right away",
dest="interactive"
)
parser.add_argument(
"-ins",
"--instruct",
action="store_true",
help="run in instruction mode (use with Alpaca or Vicuna models)",
dest="instruct"
)
parser.add_argument(
"--color",
action="store_true",
help="colorise output to distinguish prompt and user input from generations",
dest="use_color"
)
parser.add_argument("--mlock", action="store_true",dest="use_mlock")
parser.add_argument("--mtest", action="store_true",dest="mem_test")
parser.add_argument(
"-r",
"--reverse-prompt",
type=str,
action='append',
help="run in interactive mode and poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
dest="antiprompt"
)
parser.add_argument("--perplexity", action="store_true", help="", dest="perplexity")
parser.add_argument("--ignore-eos", action="store_true", help="", dest="ignore_eos")
parser.add_argument("--n_parts", type=int, default=-1, help="", dest="n_parts")
parser.add_argument("--random-prompt", action="store_true", help="", dest="random_prompt")
parser.add_argument("--in-prefix", type=str, default=" ", help="", dest="input_prefix")
parser.add_argument("--fix-prefix", type=str, default=" ", help="", dest="fix_prefix")
parser.add_argument("--out-postfix", type=str, default="", help="", dest="output_postfix")
parser.add_argument("--input-noecho", action="store_false", help="", dest="input_echo")
args = parser.parse_args(argv)
return args
def gpt_random_prompt(rng):
    return [
        "So",
        "Once upon a time",
        "When",
        "The",
        "After",
        "If",
        "import",
        "He",
        "She",
        "They",
    ][rng % 10]
if __name__ == "__main__":
    args = gpt_params_parse()
    # Build a GptParams from the parsed flags, dropping argparse-only keys (help, file)
    # and unset values so the dataclass defaults still apply.
    print(GptParams(**{k: v for k, v in vars(args).items() if k in GptParams.__dataclass_fields__ and v is not None}))
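
Since the parser's dest names line up with the GptParams fields, the parsed flags can be folded into a params object directly. A minimal usage sketch (not part of the commit; the model path and reverse prompt are placeholders):

# usage sketch (illustrative values only)
from common import GptParams, gpt_params_parse

args = gpt_params_parse(["-m", "./models/7B/ggml-model-q4_0.bin", "-i", "-r", "User:"])
params = GptParams(**{k: v for k, v in vars(args).items()
                      if k in GptParams.__dataclass_fields__ and v is not None})
print(params.model, params.interactive, params.antiprompt)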

@@ -12,102 +12,182 @@ Quirks:
 You should also still be feeding the model with a "primer" prompt that
 shows it the expected format.
 """
+import sys
+from time import time
+from os import cpu_count
 import llama_cpp
+from common import GptParams, gpt_params_parse, gpt_random_prompt
+
+ANSI_COLOR_RESET = "\x1b[0m"
+ANSI_COLOR_YELLOW = "\x1b[33m"
+ANSI_BOLD = "\x1b[1m"
+ANSI_COLOR_GREEN = "\x1b[32m"
+
+CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
+CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
+CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
 
 # A LLaMA interactive session
 class LLaMAInteract:
-    def __init__(self,
-        primer: str="",
-        model: str="./models/30B/ggml-model-q4_0.bin",
-        instruct: bool=False,
-        n_ctx: int=1024,
-        seed: int=0,
-        n_threads: int=8,
-        antiprompt: list[str]=[],
-        input_echo: bool=True,
-        n_predict: int=20,
-        n_keep: int=0,
-        n_batch: int=8,
-        repeat_last_n: int=64,
-        top_k: int=50,
-        top_p: float=1.,
-        temp: float=1.0,
-        repeat_penalty: float=1,
-        init_break: bool=True,
-        instruct_inp_prefix: str="\n\n### Instruction:\n\n",
-        instruct_inp_suffix: str="\n\n### Response:\n\n",
-    ) -> None:
+    def __init__(self, params: GptParams) -> None:
         # input args
-        self.instruct = instruct
-        self.n_threads = n_threads
-        self.input_echo = input_echo
-        self.n_predict = n_predict
-        self.n_keep = n_keep
-        self.n_batch = n_batch
-        self.repeat_last_n = repeat_last_n
-        self.top_k=top_k
-        self.top_p=top_p
-        self.temp=temp
-        self.repeat_penalty=repeat_penalty
-        self.init_break = init_break
+        self.params = params
+
+        if (self.params.perplexity):
+            raise NotImplementedError("""************
+please use the 'perplexity' tool for perplexity calculations
+************""")
+
+        if (self.params.embedding):
+            raise NotImplementedError("""************
+please use the 'embedding' tool for embedding calculations
+************""")
+
+        if (self.params.n_ctx > 2048):
+            print(f"""warning: model does not support \
+context sizes greater than 2048 tokens ({self.params.n_ctx} \
+specified) expect poor results""", file=sys.stderr)
+
+        if (self.params.seed <= 0):
+            self.params.seed = int(time())
+
+        print(f"seed = {self.params.seed}", file=sys.stderr)
+
+        if (self.params.random_prompt):
+            self.params.prompt = gpt_random_prompt(self.params.seed)
 
         # runtime args
         self.input_consumed = 0
         self.embd = []
-        self.embd_inp = []
         self.n_past = 0
         self.first_antiprompt = []
-        self.remaining_tokens = self.n_predict
-        self.output_echo = input_echo
+        self.remaining_tokens = self.params.n_predict
+        self.output_echo = self.params.input_echo
 
         # model load
         self.lparams = llama_cpp.llama_context_default_params()
-        self.lparams.n_ctx = n_ctx
-        self.lparams.seed = seed
-        self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
+        self.lparams.n_ctx = self.params.n_ctx
+        self.lparams.n_parts = self.params.n_parts
+        self.lparams.seed = self.params.seed
+        self.lparams.memory_f16 = self.params.memory_f16
+        self.lparams.use_mlock = self.params.use_mlock
+
+        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
+        if (self.ctx == 0):
+            raise RuntimeError(f"error: failed to load model '{self.params.model}'")
+
+        print(file=sys.stderr)
+        print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
+| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
 
         # determine the required inference memory per token:
-        tmp = [0, 1, 2, 3]
-        llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
-
-        # determine newline token
-        self.llama_token_newline = self._tokenize("\n", False)
-        self.inp_prefix = self._tokenize(instruct_inp_prefix)
-        self.inp_suffix = self._tokenize(instruct_inp_suffix, False)
-
-        # add instruction as antiprompt
-        if (self.instruct):
-            self.first_antiprompt.append(self._tokenize(instruct_inp_prefix.strip(), False))
-
-        # primer feed
-        if (len(primer) > 0):
-            self.embd_inp += self._tokenize(primer)
-
-        # number of tokens to keep when resetting context
-        if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
-            self.n_keep = len(self.embd_inp)
+        if (self.params.mem_test):
+            tmp = [0, 1, 2, 3]
+            llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
+            llama_cpp.llama_print_timings(self.ctx)
+            self.exit()
+            return
 
         # create internal context
         self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
-        self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
+
+        # Add a space in front of the first character to match OG llama tokenizer behavior
+        self.params.prompt = " " + self.params.prompt
+
+        # tokenize the prompt
+        self.embd_inp = self._tokenize(self.params.prompt)
+
+        if (len(self.embd_inp) > self.params.n_ctx - 4):
+            raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
+
+        # number of tokens to keep when resetting context
+        if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
+            self.params.n_keep = len(self.embd_inp)
+
+        self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
+        self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)
+
+        # in instruct mode, we inject a prefix and a suffix to each input by the user
+        if (self.params.instruct):
+            self.params.interactive_start = True
+            self.first_antiprompt.append(self._tokenize(self.params.instruct_inp_prefix.strip(), False))
+
+        # enable interactive mode if reverse prompt or interactive start is specified
+        if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
+            self.params.interactive = True
+
+        # determine newline token
+        self.llama_token_newline = self._tokenize("\n", False)
+
+        if (self.params.verbose_prompt):
+            print(f"""
+prompt: '{self.params.prompt}'
+number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
+
+            for i in range(len(self.embd_inp)):
+                print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+
+            if (self.params.n_keep > 0):
+                print("static prompt based on n_keep: '")
+                for i in range(self.params.n_keep):
+                    print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                print("'", file=sys.stderr)
+            print(file=sys.stderr)
+
+        if (self.params.interactive):
+            print("interactive mode on.", file=sys.stderr)
+
+            if (len(self.params.antiprompt) > 0):
+                for antiprompt in self.params.antiprompt:
+                    print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)
+
+            if len(self.params.input_prefix) > 0:
+                print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)
+
+        print(f"""sampling: temp = {self.params.temp},\
+top_k = {self.params.top_k},\
+top_p = {self.params.top_p},\
+repeat_last_n = {self.params.repeat_last_n},\
+repeat_penalty = {self.params.repeat_penalty}
+generate: n_ctx = {self.n_ctx}, \
+n_batch = {self.params.n_batch}, \
+n_predict = {self.params.n_predict}, \
+n_keep = {self.params.n_keep}
+""", file=sys.stderr)
 
         # determine antiprompt tokens
-        for i in antiprompt:
+        for i in self.params.antiprompt:
             self.first_antiprompt.append(self._tokenize(i, False))
 
+        self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
+
+        if (params.interactive):
+            print("""== Running in interactive mode. ==
+ - Press Ctrl+C to interject at any time.
+ - Press Return to return control to LLaMa.
+ - If you want to submit another line, end your input in '\\'.
+""", file=sys.stderr)
+        self.set_color(CONSOLE_COLOR_PROMPT)
+
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
         _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
         return _arr[:_n]
 
+    # if an antiprompt is present
     def use_antiprompt(self):
         return len(self.first_antiprompt) > 0
 
+    def set_color(self, c):
+        if (self.params.use_color):
+            print(c)
+
     # generate tokens
     def generate(self):
-        while self.remaining_tokens > 0 or self.use_antiprompt():
+        while self.remaining_tokens > 0 or self.params.interactive:
             # predict
             if len(self.embd) > 0:
                 # infinite text generation via context swapping
@@ -115,8 +195,8 @@ class LLaMAInteract:
                 # - take the n_keep first tokens from the original prompt (via n_past)
                 # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                 if (self.n_past + len(self.embd) > self.n_ctx):
-                    n_left = self.n_past - self.n_keep
-                    self.n_past = self.n_keep
+                    n_left = self.n_past - self.params.n_keep
+                    self.n_past = self.params.n_keep
 
                     # insert n_left/2 tokens at the start of embd from last_n_tokens
                     _insert = self.last_n_tokens[
@@ -125,7 +205,7 @@ class LLaMAInteract:
                     self.embd = _insert + self.embd
 
                 if (llama_cpp.llama_eval(
-                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
+                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
                 ) != 0):
                     raise Exception("Failed to llama_eval!")
@@ -133,24 +213,28 @@ class LLaMAInteract:
             self.embd = []
             if len(self.embd_inp) <= self.input_consumed:
                 # out of user input, sample next token
-                _arr = self.last_n_tokens[-min(self.repeat_last_n, self.n_past):]
+                #TODO: self.params.ignore_eos
+                _arr = self.last_n_tokens[-min(self.params.repeat_last_n, self.n_past):]
                 id = llama_cpp.llama_sample_top_p_top_k(
                     self.ctx,
                     (llama_cpp.llama_token * len(_arr))(*_arr),
                     len(_arr),
-                    self.top_k,
-                    self.top_p,
-                    self.temp,
-                    self.repeat_penalty,
+                    self.params.top_k,
+                    self.params.top_p,
+                    self.params.temp,
+                    self.params.repeat_penalty,
                 )
                 self.last_n_tokens.pop(0)
                 self.last_n_tokens.append(id)
 
                 # replace end of text token with newline token when in interactive mode
-                if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
+                if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
                     id = self.llama_token_newline[0]
-                    # tokenize and inject first reverse prompt
-                    self.embd_inp += self.first_antiprompt[0]
+                    if (self.use_antiprompt()):
+                        # tokenize and inject first reverse prompt
+                        self.embd_inp += self.first_antiprompt[0]
 
                 # add it to the context
                 self.embd.append(id)
@@ -162,7 +246,7 @@ class LLaMAInteract:
                 self.remaining_tokens -= 1
             else:
                 # output to console if input echo is on
-                self.output_echo = self.input_echo
+                self.output_echo = self.params.input_echo
 
                 # some user input remains from prompt or interaction, forward it to processing
                 while len(self.embd_inp) > self.input_consumed:
@@ -170,7 +254,7 @@ class LLaMAInteract:
                     self.last_n_tokens.pop(0)
                     self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                     self.input_consumed += 1
-                    if len(self.embd) >= self.n_batch:
+                    if len(self.embd) >= self.params.n_batch:
                         break
 
             # display tokens
@@ -178,7 +262,11 @@ class LLaMAInteract:
                 for id in self.embd:
                     yield id
 
-            if (len(self.embd_inp) <= self.input_consumed):
+            # reset color to default if there is no pending user input
+            if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
+                self.set_color(CONSOLE_COLOR_DEFAULT)
+
+            if (self.params.interactive and len(self.embd_inp) <= self.input_consumed):
                 # if antiprompt is present, stop
                 if (self.use_antiprompt()):
                     if True in [
@@ -188,26 +276,36 @@ class LLaMAInteract:
                         break
 
                 # if we are using instruction mode, and we have processed the initial prompt
-                if (self.init_break):
+                if (self.n_past > 0 and self.params.interactive_start):
                     break
 
-            # if end of generation
+            # end of text token
             if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+                if (not self.params.instruct):
+                    for i in " [end of text]\n":
+                        yield i
                 break
 
             # respect n_predict even if antiprompt is present
-            if (self.use_antiprompt() and self.remaining_tokens <= 0 and self.n_predict != -1):
-                if not self.instruct:
+            if (self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1):
+                # If we aren't in instruction mode, fix the current generation by appending the antiprompt.
+                # Makes it so if the chat ends prematurely you don't append the AI's text etc.
+                if not self.params.instruct:
                     self.embd_inp += self.first_antiprompt[0]
+                    self.n_remain = self.params.n_predict
                 break
 
-        self.init_break = False
+        self.params.interactive_start = False
 
     def __enter__(self):
         return self
 
     def __exit__(self, type, value, tb):
+        self.exit()
+
+    def exit(self):
         llama_cpp.llama_free(self.ctx)
+        self.set_color(CONSOLE_COLOR_DEFAULT)
 
     # return past text
     def past(self):
@@ -216,18 +314,51 @@ class LLaMAInteract:
     # write input
     def input(self, prompt: str):
-        if (self.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix):
+        if (self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix):
             self.embd_inp += self.inp_prefix
         self.embd_inp += self._tokenize(prompt)
-        if (self.instruct):
+        if (self.params.instruct):
             self.embd_inp += self.inp_suffix
 
     # write output
     def output(self):
-        self.remaining_tokens = self.n_predict
+        self.remaining_tokens = self.params.n_predict
         for id in self.generate():
             yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
 
+    # read user input
+    def read_input(self):
+        out = ""
+        while (t := input()).endswith("\\"):
+            out += t[:-1] + "\n"
+        return out + t + "\n"
+
+    # interactive mode
+    def interact(self):
+        for i in self.output():
+            print(i,end="",flush=True)
+        self.params.input_echo = False
+
+        while self.params.interactive:
+            self.set_color(CONSOLE_COLOR_USER_INPUT)
+            if (self.params.instruct):
+                print('\n> ', end="")
+                self.input(self.read_input())
+            else:
+                print(self.params.input_prefix, end="")
+                self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.output_postfix}")
+                print(self.params.output_postfix,end="")
+            self.set_color(CONSOLE_COLOR_DEFAULT)
+
+            try:
+                for i in self.output():
+                    print(i,end="",flush=True)
+            except KeyboardInterrupt:
+                self.set_color(CONSOLE_COLOR_DEFAULT)
+                if not self.params.instruct:
+                    print(self.params.fix_prefix,end="")
+                    self.input(self.params.fix_prefix)
+
 if __name__ == "__main__":
     from datetime import datetime
@@ -252,41 +383,12 @@ The transcript only includes text, it does not include markup like HTML and Mark
 {USER_NAME}: Name a color.
 {AI_NAME}: Blue
 {USER_NAME}:"""
 
-    print("Loading model...")
-    with LLaMAInteract(prompt,
-        model="./models/30B/ggml-model-q4_0.bin",
-        n_ctx=2048,
-        antiprompt=[f"\n{USER_NAME}:"],
-        repeat_last_n=256,
-        n_predict=2048,
-        temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
-    ) as m:
-        print("Loaded model!")
-
-        for i in m.output():
-            print(i,end="",flush=True)
-        m.input_echo = False
-
-        def inp():
-            out = ""
-            while (t := input()).endswith("\\"):
-                out += t[:-1] + "\n"
-            return out + t + "\n"
-
-        while True:
-            if (m.instruct):
-                print('\n> ', end="")
-                m.input(inp())
-            else:
-                print(f" ", end="")
-                m.input(f" {inp()}{AI_NAME}:")
-                print(f"{AI_NAME}: ",end="")
-
-            try:
-                for i in m.output():
-                    print(i,end="",flush=True)
-            except KeyboardInterrupt:
-                if not m.instruct:
-                    print(f"\n{USER_NAME}:",end="")
-                    m.input(f"\n{USER_NAME}:")
+    args = gpt_params_parse()
+    params = GptParams(**{k: v for k, v in vars(args).items() if k in GptParams.__dataclass_fields__ and v is not None})
+
+    if (args.file):
+        with open(args.file) as f:
+            params.prompt = f.read()
+
+    with LLaMAInteract(params) as m:
+        m.interact()