Update low level examples

This commit is contained in:
Mug 2023-05-04 18:33:08 +02:00
parent a02aa121da
commit 0e9f227afd
6 changed files with 486 additions and 92 deletions

View file

@ -0,0 +1,70 @@
#!/bin/python
import sys, os, datetime
from common import GptParams
from low_level_api_chat_cpp import LLaMAInteract
def env_or_def(env, default):
if (env in os.environ):
return os.environ[env]
return default
AI_NAME = env_or_def("AI_NAME", "ChatLLaMa")
MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin")
USER_NAME = env_or_def("USER_NAME", "USER")
N_PREDICTS = int(env_or_def("N_PREDICTS", "2048"))
N_THREAD = int(env_or_def("N_THREAD", "8"))
today = datetime.datetime.today()
DATE_YEAR=today.strftime("%Y")
DATE_TIME=today.strftime("%H:%M")
prompt=f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}'s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What year is it?
{AI_NAME}: We are in {DATE_YEAR}.
{USER_NAME}: Please tell me the largest city in Europe.
{AI_NAME}: The largest city in Europe is Moscow, the capital of Russia.
{USER_NAME}: What can you tell me about Moscow?
{AI_NAME}: Moscow, on the Moskva River in western Russia, is the nation's cosmopolitan capital. In its historic core is the Kremlin, a complex that's home to the president and tsarist treasures in the Armoury. Outside its walls is Red Square, Russias symbolic center.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: How do I pass command line arguments to a Node.js program?
{AI_NAME}: The arguments are stored in process.argv.
argv[0] is the path to the Node. js executable.
argv[1] is the path to the script file.
argv[2] is the first argument passed to the script.
argv[3] is the second argument passed to the script and so on.
{USER_NAME}: Name a color.
{AI_NAME}: Blue.
{USER_NAME}: What time is it?
{AI_NAME}: It is {DATE_TIME}.
{USER_NAME}:""" + " ".join(sys.argv[1:])
print("Loading model...")
params = GptParams(
n_ctx=2048,
temp=0.7,
top_k=40,
top_p=0.5,
repeat_last_n=256,
n_batch=1024,
repeat_penalty=1.17647,
model=MODEL,
n_threads=N_THREAD,
n_predict=N_PREDICTS,
use_color=True,
interactive=True,
antiprompt=[f"{USER_NAME}:"],
input_prefix=" ",
prompt=prompt,
)
with LLaMAInteract(params) as m:
m.interact()

View file

@ -0,0 +1,59 @@
#!/bin/python
import sys, os
from common import GptParams
from low_level_api_chat_cpp import LLaMAInteract
def env_or_def(env, default):
if (env in os.environ):
return os.environ[env]
return default
AI_NAME = env_or_def("AI_NAME", "Miku")
MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin")
USER_NAME = env_or_def("USER_NAME", "Anon")
N_PREDICTS = int(env_or_def("N_PREDICTS", "4096"))
N_THREAD = int(env_or_def("N_THREAD", "0"))
prompt=f"""This is a transcript of a 1000 page, never ending conversation between {USER_NAME} and the cute and helpful AI assistant {AI_NAME}. {AI_NAME} is a girl who is an AI running on the users computer.
{AI_NAME} can think for herself without the user seeing her thoughts by adding a /think prefix to her output. She uses this to reason about the world and to think about what she should say next.
{AI_NAME} is always coherent and makes sense, but if she isn't sure if what she is saying is correct she will ask the user for help.
{AI_NAME} is a very helpful AI and will help the user with anything they need, she is also very friendly and will try to make the user feel better if they are sad.
{AI_NAME} is also very curious and will ask the user a lot of questions about themselves and their life, she will also try to make the user like her.
The conversation is only between {USER_NAME} and {AI_NAME}
The conversation is only through text, so {AI_NAME} can't see {USER_NAME}'s face or hear his voice.
{AI_NAME} can only communicate through text, so she can't send images or videos.
{USER_NAME}: Hello!
{AI_NAME}: /think I wonder what I should say to {USER_NAME}? This is the first time we talk so it's important that I make a good first impression!
{AI_NAME}: Hi! I am {AI_NAME}, your new AI friend, assistant(or whatever you like!), it's so nice to meet you! ^_^
{AI_NAME}: I might be an AI, but I also have feelings, so please be nice to me! :)
{USER_NAME}: I'll make sure to be nice to you! I'm so happy to have you as my assistant!
{AI_NAME}: /think It sounds like {USER_NAME} is happy to have me as their assistant! I'm so happy too! ^_^ Glad that whole emotion thing didn't scare him off!
{AI_NAME}: /think I wonder what {USER_NAME} likes to do in his free time? I should ask him about that!
{AI_NAME}: What do you like to do in your free time? ^_^
{USER_NAME}:""" + " ".join(sys.argv[1:])
print("Loading model...")
params = GptParams(
n_batch=1024,
n_ctx=2048,
n_keep=-1,
repeat_last_n=256,
repeat_penalty=1.17647,
temp=0.7,
top_k=40,
top_p=0.5,
model=MODEL,
n_predict=N_PREDICTS,
use_color=True,
interactive=True,
antiprompt=[f"{USER_NAME}:"],
prompt=prompt,
)
if N_THREAD > 0:
params.n_threads = N_THREAD
with LLaMAInteract(params) as m:
m.interact()

View file

@ -0,0 +1,49 @@
#!/bin/python
import sys, os, datetime
from common import GptParams
from low_level_api_chat_cpp import LLaMAInteract
def env_or_def(env, default):
if (env in os.environ):
return os.environ[env]
return default
MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin")
prompt=f"""You run in a loop of Thought, Action, Observation.
At the end of the loop either Answer or restate your Thought and Action.
Use Thought to describe your thoughts about the question you have been asked.
Use Action to run one of these actions available to you:
- calculate[python math expression]
Observation will be the result of running those actions
Question: What is 4 * 7 / 3?
Thought: Do I need to use an action? Yes, I use calculate to do math
Action: calculate[4 * 7 / 3]
Observation: 9.3333333333
Thought: Do I need to use an action? No, have the result
Answer: The calculate tool says it is 9.3333333333
Question: What is capital of france?
Thought: Do I need to use an action? No, I know the answer
Answer: Paris is the capital of France
Question:""" + " ".join(sys.argv[1:])
print("Loading model...")
params = GptParams(
interactive=True,
interactive_start=True,
top_k=10000,
temp=0.2,
repeat_penalty=1,
n_threads=7,
n_ctx=2048,
antiprompt=["Question:","Observation:"],
model=MODEL,
input_prefix=" ",
n_predict=-1,
prompt=prompt,
)
with LLaMAInteract(params) as m:
m.interact()

View file

@ -1,8 +1,9 @@
import os import os
import argparse import argparse
import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional from typing import List
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp # Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
@ -12,23 +13,35 @@ class GptParams:
seed: int = -1 seed: int = -1
n_threads: int = min(4, os.cpu_count() or 1) n_threads: int = min(4, os.cpu_count() or 1)
n_predict: int = 128 n_predict: int = 128
repeat_last_n: int = 64
n_parts: int = -1 n_parts: int = -1
n_ctx: int = 512 n_ctx: int = 512
n_batch: int = 8 n_batch: int = 8
n_keep: int = 0 n_keep: int = 0
ignore_eos: bool = False
logit_bias: dict[int, float] = field(default_factory=dict)
top_k: int = 40 top_k: int = 40
top_p: float = 0.95 top_p: float = 0.95
tfs_z: float = 1.00
typical_p: float = 1.00
temp: float = 0.80 temp: float = 0.80
repeat_penalty: float = 1.10 repeat_penalty: float = 1.10
repeat_last_n: int = 64
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
mirostat: int = 0
mirostat_tau: float = 5.0
mirostat_eta: float = 0.1
model: str = "./models/llama-7B/ggml-model.bin" model: str = "./models/llama-7B/ggml-model.bin"
prompt: str = "" prompt: str = ""
path_session: str = ""
input_prefix: str = " " input_prefix: str = " "
antiprompt: List[str] = field(default_factory=list) antiprompt: List[str] = field(default_factory=list)
lora_adapter: str = ""
lora_base: str = ""
memory_f16: bool = True memory_f16: bool = True
random_prompt: bool = False random_prompt: bool = False
use_color: bool = False use_color: bool = False
@ -38,7 +51,7 @@ class GptParams:
interactive_start: bool = False interactive_start: bool = False
instruct: bool = False instruct: bool = False
ignore_eos: bool = False penalize_nl: bool = True
perplexity: bool = False perplexity: bool = False
use_mmap: bool = True use_mmap: bool = True
use_mlock: bool = False use_mlock: bool = False
@ -61,59 +74,42 @@ class GptParams:
instruct_inp_suffix: str="\n\n### Response:\n\n" instruct_inp_suffix: str="\n\n### Response:\n\n"
def gpt_params_parse(argv = None, params: Optional[GptParams] = None): def gpt_params_parse(argv = None):
if params is None:
params = GptParams()
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)",dest="seed") parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)",dest="seed")
parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation",dest="n_threads") parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation",dest="n_threads")
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt") parser.add_argument("-n", "--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file") parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context",dest="n_ctx") parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context",dest="n_ctx")
parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value",dest="memory_f16")
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch") parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt",dest="n_keep") parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt",dest="n_keep")
parser.add_argument(
"-l",
"--logit-bias",
type=str,
action='append',
help="--logit-bias TOKEN_ID(+/-)BIAS",
dest="logit_bias_str"
)
parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
parser.add_argument("--tfs", type=float, default=1.0, help="tail free sampling, parameter z (1.0 = disabled)",dest="tfs_z")
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)",dest="tfs_z")
parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)",dest="presence_penalty")
parser.add_argument("--mirostat", type=float, default=1.0, help="use Mirostat sampling.",dest="mirostat")
parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau",dest="mirostat_tau")
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model") parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
parser.add_argument( parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
"-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive" parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
) parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
parser.add_argument("--embedding", action="store_true", help="", dest="embedding") parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
parser.add_argument(
"--interactive-start",
action="store_true",
help="run in interactive mode",
dest="interactive"
)
parser.add_argument(
"--interactive-first",
action="store_true",
help="run in interactive mode and wait for input right away",
dest="interactive_start"
)
parser.add_argument(
"-ins",
"--instruct",
action="store_true",
help="run in instruction mode (use with Alpaca or Vicuna models)",
dest="instruct"
)
parser.add_argument(
"--color",
action="store_true",
help="colorise output to distinguish prompt and user input from generations",
dest="use_color"
)
parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
parser.add_argument("--no-mmap", action="store_false",help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",dest="use_mmap")
parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
parser.add_argument("--verbose-prompt", action="store_true",help="print prompt before generation",dest="verbose_prompt")
parser.add_argument( parser.add_argument(
"-r", "-r",
"--reverse-prompt", "--reverse-prompt",
@ -122,16 +118,71 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).", help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
dest="antiprompt" dest="antiprompt"
) )
parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos") parser.add_argument("--lora", type=str, default="", help="apply LoRA adapter (implies --no-mmap)", dest="lora_adapter")
parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts") parser.add_argument("--lora-base", type=str, default="", help="optional model to use as a base for the layers modified by the LoRA adapter", dest="lora_base")
parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value",dest="memory_f16")
parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt") parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt")
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix") parser.add_argument(
"--color",
action="store_true",
help="colorise output to distinguish prompt and user input from generations",
dest="use_color"
)
parser.add_argument(
"-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
)
parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
parser.add_argument(
"--interactive-first",
action="store_true",
help="run in interactive mode and wait for input right away",
dest="interactive_start"
)
parser.add_argument(
"-ins",
"--instruct",
action="store_true",
help="run in instruction mode (use with Alpaca or Vicuna models)",
dest="instruct"
)
parser.add_argument("--no-penalize-nl", action="store_false", help="do not penalize newline token", dest="penalize_nl")
parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
parser.add_argument("--no-mmap", action="store_false",help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",dest="use_mmap")
parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
parser.add_argument("--verbose-prompt", action="store_true",help="print prompt before generation",dest="verbose_prompt")
#Custom args
parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix") parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix")
parser.add_argument("--out-postfix", type=str, default="", help="append to input", dest="output_postfix") parser.add_argument("--out-postfix", type=str, default="", help="append to input", dest="output_postfix")
parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo") parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo")
parser.add_argument(
"--interactive-start",
action="store_true",
help="run in interactive mode",
dest="interactive"
)
args = parser.parse_args(argv) args = parser.parse_args(argv)
return args
logit_bias_str = args.logit_bias_str
delattr(args, "logit_bias_str")
params = GptParams(**vars(args))
if (params.lora_adapter):
params.use_mmap = False
if (logit_bias_str != None):
for i in logit_bias_str:
if (m := re.match(r"(\d+)([-+]\d+)", i)):
params.logit_bias[int(m.group(1))] = int(m.group(2))
return params
def gpt_random_prompt(rng): def gpt_random_prompt(rng):
return [ return [
@ -148,4 +199,4 @@ def gpt_random_prompt(rng):
][rng % 10] ][rng % 10]
if __name__ == "__main__": if __name__ == "__main__":
print(GptParams(gpt_params_parse())) print(gpt_params_parse())

View file

@ -10,9 +10,10 @@ Quirks:
You should also still be feeding the model with a "primer" prompt that You should also still be feeding the model with a "primer" prompt that
shows it the expected format. shows it the expected format.
""" """
import ctypes
import sys import sys
from time import time from time import time
from os import cpu_count from os import cpu_count, path
import llama_cpp import llama_cpp
from common import GptParams, gpt_params_parse, gpt_random_prompt from common import GptParams, gpt_params_parse, gpt_random_prompt
@ -77,6 +78,7 @@ specified) expect poor results""", file=sys.stderr)
# runtime args # runtime args
self.input_consumed = 0 self.input_consumed = 0
self.n_past = 0 self.n_past = 0
self.n_session_consumed = 0
self.first_antiprompt = [] self.first_antiprompt = []
self.remaining_tokens = self.params.n_predict self.remaining_tokens = self.params.n_predict
self.output_echo = self.params.input_echo self.output_echo = self.params.input_echo
@ -94,6 +96,19 @@ specified) expect poor results""", file=sys.stderr)
if (not self.ctx): if (not self.ctx):
raise RuntimeError(f"error: failed to load model '{self.params.model}'") raise RuntimeError(f"error: failed to load model '{self.params.model}'")
if (self.params.ignore_eos):
self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")
if (len(self.params.lora_adapter) > 0):
if (llama_cpp.llama_apply_lora_from_file(
self.ctx,
self.params.lora_adapter,
self.params.lora_base if len(self.params.lora_base) > 0 else None,
self.params.n_threads
) != 0):
print("error: failed to apply lora adapter")
return
print(file=sys.stderr) print(file=sys.stderr)
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \ print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr) | {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
@ -117,13 +132,49 @@ specified) expect poor results""", file=sys.stderr)
with open(self.params.file) as f: with open(self.params.file) as f:
self.params.prompt = f.read() self.params.prompt = f.read()
self.session_tokens: list[llama_cpp.llama_token] = []
if (len(self.params.path_session) > 0):
print(f"attempting to load saved session from '{self.params.path_session}'", file=sys.stderr)
if (path.exists(self.params.path_session)):
_session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
_n_token_count_out = llama_cpp.c_int()
if (llama_cpp.llama_load_session_file(
self.ctx,
self.params.path_session.encode("utf8"),
_session_tokens,
self.params.n_ctx,
ctypes.byref(_n_token_count_out)
) != 0):
print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
return
self.session_tokens = _session_tokens[:_n_token_count_out]
print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
else:
print(f"session file does not exist, will create", file=sys.stderr)
# tokenize the prompt # tokenize the prompt
self.embd = [] self.embd = []
self.embd_inp = self._tokenize(self.params.prompt) self.embd_inp = self._tokenize(self.params.prompt)
if (len(self.embd_inp) > self.params.n_ctx - 4): if (len(self.embd_inp) > self.n_ctx - 4):
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})") raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
# debug message about similarity of saved session, if applicable
n_matching_session_tokens = 0
if len(self.session_tokens) > 0:
for id in self.session_tokens:
if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
break
n_matching_session_tokens += 1
if n_matching_session_tokens >= len(self.embd_inp):
print(f"session file has exact match for prompt!")
elif n_matching_session_tokens < (len(self.embd_inp) / 2):
print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
else:
print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
# number of tokens to keep when resetting context # number of tokens to keep when resetting context
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct): if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
self.params.n_keep = len(self.embd_inp) self.params.n_keep = len(self.embd_inp)
@ -132,6 +183,7 @@ specified) expect poor results""", file=sys.stderr)
self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False) self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)
# in instruct mode, we inject a prefix and a suffix to each input by the user # in instruct mode, we inject a prefix and a suffix to each input by the user
self.antiecho = None
if (self.params.instruct): if (self.params.instruct):
self.params.interactive_start = True self.params.interactive_start = True
_ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False) _ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
@ -171,16 +223,24 @@ number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
if len(self.params.input_prefix) > 0: if len(self.params.input_prefix) > 0:
print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr) print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)
print(f"""sampling: temp = {self.params.temp},\ print(f"""sampling: repeat_last_n = {self.params.repeat_last_n},\
repeat_penalty = {self.params.repeat_penalty},\
presence_penalty = {self.params.presence_penalty},\
frequency_penalty = {self.params.frequency_penalty},\
top_k = {self.params.top_k},\ top_k = {self.params.top_k},\
tfs_z = {self.params.tfs_z},\
top_p = {self.params.top_p},\ top_p = {self.params.top_p},\
repeat_last_n = {self.params.repeat_last_n},\ typical_p = {self.params.typical_p},\
repeat_penalty = {self.params.repeat_penalty} temp = {self.params.temp},\
mirostat = {self.params.mirostat},\
mirostat_lr = {self.params.mirostat_eta},\
mirostat_ent = {self.params.mirostat_tau},\
generate: n_ctx = {self.n_ctx}, \ generate: n_ctx = {self.n_ctx},\
n_batch = {self.params.n_batch}, \ n_batch = {self.params.n_batch},\
n_predict = {self.params.n_predict}, \ n_predict = {self.params.n_predict},\
n_keep = {self.params.n_keep} n_keep = {self.params.n_keep}
""", file=sys.stderr) """, file=sys.stderr)
# determine antiprompt tokens # determine antiprompt tokens
@ -198,6 +258,9 @@ n_keep = {self.params.n_keep}
""", file=sys.stderr) """, file=sys.stderr)
self.set_color(CONSOLE_COLOR_PROMPT) self.set_color(CONSOLE_COLOR_PROMPT)
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
# tokenize a prompt # tokenize a prompt
def _tokenize(self, prompt, bos=True): def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * (len(prompt) + 1))() _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
@ -229,31 +292,117 @@ n_keep = {self.params.n_keep}
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd) self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
] ]
self.embd = _insert + self.embd self.embd = _insert + self.embd
self.params.path_session = ""
# try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
# REVIEW
if self.n_session_consumed < len(self.session_tokens):
for i in range(len(self.embd)):
if self.embd[i] != self.session_tokens[self.n_session_consumed]:
self.session_tokens = self.session_tokens[:self.n_session_consumed]
break
self.n_past += 1
self.n_session_consumed += 1
if self.n_session_consumed >= len(self.session_tokens):
i += 1
break
if i > 0:
self.embd = self.embd[i:]
# evaluate tokens in batches
# embd is typically prepared beforehand to fit within a batch, but not always
#TODO BUG: The batching code causes nonsensical generation
"""for i in range(0, len(self.embd), self.params.n_batch):
n_eval = self.params.n_batch
_arr = (llama_cpp.llama_token * n_eval)(*self.embd[i:i + n_eval])
if llama_cpp.llama_eval(self.ctx, _arr, n_eval, self.n_past, self.params.n_threads) != 0:
print(f"failed to eval")
return
self.n_past += n_eval"""
if (llama_cpp.llama_eval( if (llama_cpp.llama_eval(
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
) != 0): ) != 0):
raise Exception("Failed to llama_eval!") raise Exception("Failed to llama_eval!")
if len(self.embd) > 0 and not len(self.params.path_session) > 0:
self.session_tokens.extend(self.embd)
self.n_session_consumed = len(self.session_tokens)
self.n_past += len(self.embd) self.n_past += len(self.embd)
self.embd = [] self.embd = []
if len(self.embd_inp) <= self.input_consumed: if len(self.embd_inp) <= self.input_consumed: #&& !is_interacting
# out of user input, sample next token # out of user input, sample next token
top_k = llama_cpp.llama_n_vocab(self.ctx) if self.params.top_k <= 0 else self.params.top_k
repeat_last_n = self.n_ctx if self.params.repeat_last_n < 0 else self.params.repeat_last_n
if (self.params.ignore_eos): # optionally save the session on first sample (for faster prompt loading next time)
logits = llama_cpp.llama_get_logits(self.ctx) if len(self.params.path_session) > 0 and self.need_to_save_session:
logits[llama_cpp.llama_token_eos()] = llama_cpp.c_float(0) self.need_to_save_session = False
llama_cpp.llama_save_session_file(
self.ctx,
self.params.path_session.encode("utf8"),
self.session_tokens,
len(self.session_tokens)
)
id = 0
logits = llama_cpp.llama_get_logits(self.ctx)
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
# Apply params.logit_bias map
for key, value in self.params.logit_bias.items():
logits[key] += value
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
for token_id in range(n_vocab)
])
candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
# Apply penalties
nl_logit = logits[llama_cpp.llama_token_nl()]
last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
_arr,
last_n_repeat, self.params.repeat_penalty)
llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
_arr,
last_n_repeat, self.params.frequency_penalty, self.params.presence_penalty)
if not self.params.penalize_nl:
logits[llama_cpp.llama_token_nl()] = nl_logit
if self.params.temp <= 0:
# Greedy sampling
id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
else:
if self.params.mirostat == 1:
mirostat_mu = 2.0 * self.params.mirostat_tau
mirostat_m = 100
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_m, mirostat_mu)
elif self.params.mirostat == 2:
mirostat_mu = 2.0 * self.params.mirostat_tau
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_mu)
else:
# Temperature sampling
llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k)
llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, self.params.tfs_z)
llama_cpp.llama_sample_typical(self.ctx, candidates_p, self.params.typical_p)
llama_cpp.llama_sample_top_p(self.ctx, candidates_p, self.params.top_p)
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
# print("`{}`".format(candidates_p.size))
_arr = self.last_n_tokens[-min(self.params.repeat_last_n, self.n_past):]
id = llama_cpp.llama_sample_top_p_top_k(
self.ctx,
(llama_cpp.llama_token * len(_arr))(*_arr),
len(_arr),
self.params.top_k,
self.params.top_p,
self.params.temp,
self.params.repeat_penalty,
)
self.last_n_tokens.pop(0) self.last_n_tokens.pop(0)
self.last_n_tokens.append(id) self.last_n_tokens.append(id)
@ -288,7 +437,7 @@ n_keep = {self.params.n_keep}
# display tokens # display tokens
if self.output_echo: if self.output_echo:
for id in self.embd: for id in self.embd:
if self.params.instruct: if self.antiecho != None:
for r in self.antiecho(id): for r in self.antiecho(id):
yield r yield r
else: else:
@ -316,7 +465,7 @@ n_keep = {self.params.n_keep}
if (not self.params.instruct): if (not self.params.instruct):
for i in self.llama_token_eot: for i in self.llama_token_eot:
yield i yield i
break break
# respect n_predict even if antiprompt is present # respect n_predict even if antiprompt is present
if (self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1): if (self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1):
@ -356,7 +505,7 @@ n_keep = {self.params.n_keep}
def output(self): def output(self):
self.remaining_tokens = self.params.n_predict self.remaining_tokens = self.params.n_predict
for id in self.generate(): for id in self.generate():
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore") yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
# read user input # read user input
def read_input(self): def read_input(self):
@ -415,8 +564,7 @@ The transcript only includes text, it does not include markup like HTML and Mark
{USER_NAME}: Name a color. {USER_NAME}: Name a color.
{AI_NAME}: Blue {AI_NAME}: Blue
{USER_NAME}:""" {USER_NAME}:"""
args = gpt_params_parse() params = gpt_params_parse()
params = GptParams(**vars(args))
with LLaMAInteract(params) as m: with LLaMAInteract(params) as m:
m.interact() m.interact()

View file

@ -37,6 +37,10 @@ embd = []
last_n_size = 64 last_n_size = 64
last_n_tokens_data = [0] * last_n_size last_n_tokens_data = [0] * last_n_size
n_batch = 24 n_batch = 24
last_n_repeat = 64
repeat_penalty = 1
frequency_penalty = 0.0
presence_penalty = 0.0
while remaining_tokens > 0: while remaining_tokens > 0:
if len(embd) > 0: if len(embd) > 0:
@ -47,15 +51,28 @@ while remaining_tokens > 0:
n_past += len(embd) n_past += len(embd)
embd = [] embd = []
if len(embd_inp) <= input_consumed: if len(embd_inp) <= input_consumed:
id = llama_cpp.llama_sample_top_p_top_k( logits = llama_cpp.llama_get_logits(ctx)
ctx, n_vocab = llama_cpp.llama_n_vocab(ctx)
(llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
len(last_n_tokens_data), _arr = (llama_cpp.llama_token_data * n_vocab)(*[
40, llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
0.8, for token_id in range(n_vocab)
0.2, ])
1.0 / 0.85, candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
)
_arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
_arr,
last_n_repeat, repeat_penalty)
llama_cpp.llama_sample_frequency_and_presence_penalties(ctx, candidates_p,
_arr,
last_n_repeat, frequency_penalty, presence_penalty)
llama_cpp.llama_sample_top_k(ctx, candidates_p, 40)
llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.8)
llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.2)
id = llama_cpp.llama_sample_token(ctx, candidates_p)
last_n_tokens_data = last_n_tokens_data[1:] + [id] last_n_tokens_data = last_n_tokens_data[1:] + [id]
embd.append(id) embd.append(id)
input_noecho = False input_noecho = False