llama.cpp/examples/low_level_api/low_level_api_chat_cpp.py

605 lines
22 KiB
Python
Raw Normal View History

2023-04-03 22:54:46 +02:00
"""
This is an example implementation of main.cpp from llama.cpp.
Quirks:
* It's not exactly alike, since this port is designed around programmatic I/O.
* Input is always echoed if on, so it should be turned off when using "input()".
* The first antiprompt should be the user prompt, like "\nUser:",
  because it's added when n_predict is reached (aka generation ended prematurely).
* n_predict can be set to -1 for unlimited length responses (or just a really high value).
* Instruction mode adds its own antiprompt.
  You should also still be feeding the model with a "primer" prompt that
  shows it the expected format.
"""
2023-05-04 18:33:08 +02:00
import ctypes
import sys
from time import time
2023-05-04 18:33:08 +02:00
from os import cpu_count, path
2023-04-03 22:54:46 +02:00
import llama_cpp
from common import GptParams, gpt_params_parse, gpt_random_prompt
2023-05-06 15:16:58 +02:00
import util
2023-04-03 22:54:46 +02:00
# A LLaMA interactive session
class LLaMAInteract:
    """Interactive LLaMA session modeled on llama.cpp's ``main.cpp`` example.

    Loads a model and a context from ``params`` on construction. Use it as a
    context manager (or call :meth:`exit`) so both are freed afterwards.
    """

    def __init__(self, params: "GptParams") -> None:
        """Validate ``params``, load the model and context, and tokenize the prompt.

        ``params`` is mutated in place (defaults for ``path_session`` and
        ``antiprompt``, seed, prompt text, ``n_keep``, ``interactive``, ...).

        Raises:
            NotImplementedError: if perplexity or embedding mode is requested.
            RuntimeError: if the model, context, LoRA adapter or session file
                fails to load, or the prompt does not fit in the context.
        """
        import codecs  # local import: only used for the streaming UTF-8 decoder

        # input args
        self.params = params
        if self.params.path_session is None:
            self.params.path_session = ""
        if self.params.antiprompt is None:
            # Iterated below as a list of antiprompt strings.
            self.params.antiprompt = []
        if self.params.perplexity:
            raise NotImplementedError(
                "************\n"
                "please use the 'perplexity' tool for perplexity calculations\n"
                "************"
            )
        if self.params.embedding:
            raise NotImplementedError(
                "************\n"
                "please use the 'embedding' tool for embedding calculations\n"
                "************"
            )
        if self.params.n_ctx > 2048:
            print(
                "warning: model does not support context sizes greater than "
                f"2048 tokens ({self.params.n_ctx} specified) expect poor results",
                file=sys.stderr,
            )
        if self.params.seed <= 0:
            self.params.seed = int(time())
        print(f"seed = {self.params.seed}", file=sys.stderr)
        if self.params.random_prompt:
            self.params.prompt = gpt_random_prompt(self.params.seed)

        # runtime args
        self.input_consumed = 0
        self.n_past = 0
        self.n_session_consumed = 0
        self.first_antiprompt = []
        self.remaining_tokens = self.params.n_predict
        self.output_echo = self.params.input_echo
        # Holds back the bytes of a UTF-8 character that is split across several
        # tokens until it is complete (see output()); persists across calls.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
        # Mirostat's running "mu" estimate (initialized to 2 * tau). The mirostat
        # samplers update it in place, so one instance is kept for the session
        # instead of being re-created for every token.
        self.mirostat_mu = llama_cpp.c_float(2.0 * self.params.mirostat_tau)

        # model load
        self.lparams = llama_cpp.llama_model_default_params()
        self.lparams.use_mlock = self.params.use_mlock
        self.lparams.use_mmap = self.params.use_mmap
        # NOTE(review): params.n_parts / params.memory_f16 are not forwarded; they
        # used to be set on the model params. Confirm whether this llama_cpp
        # version exposes an equivalent setting.
        self.model = llama_cpp.llama_load_model_from_file(
            self.params.model.encode("utf8"), self.lparams)
        if not self.model:
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        # Context params. The context below is created from these params only, so
        # context size, batch size, seed and thread count must be set here.
        self.cparams = llama_cpp.llama_context_default_params()
        self.cparams.n_ctx = self.params.n_ctx
        self.cparams.n_batch = self.params.n_batch
        self.cparams.seed = self.params.seed
        self.cparams.n_threads = self.params.n_threads
        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
        if not self.ctx:
            self.exit()
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        if self.params.ignore_eos:
            self.params.logit_bias[llama_cpp.llama_token_eos(self.ctx)] = -float("inf")

        if len(self.params.lora_adapter) > 0:
            if llama_cpp.llama_apply_lora_from_file(
                self.ctx,
                self.params.lora_adapter.encode("utf8"),
                self.params.lora_base.encode("utf8") if len(self.params.lora_base) > 0 else None,
                self.params.n_threads,
            ) != 0:
                self.exit()
                raise RuntimeError("error: failed to apply lora adapter")

        print(file=sys.stderr)
        print(
            f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} "
            f"| {llama_cpp.llama_print_system_info().decode('utf8')}",
            file=sys.stderr,
        )

        # determine the required inference memory per token:
        if self.params.mem_test:
            tmp = [0, 1, 2, 3]
            llama_cpp.llama_eval(
                self.ctx, (llama_cpp.llama_token * len(tmp))(*tmp), len(tmp), 0)
            llama_cpp.llama_print_timings(self.ctx)
            # NOTE(review): the object is unusable after this; exit() is
            # idempotent, so the later __exit__ does not double-free.
            self.exit()
            return

        # create internal context
        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)

        # Load prompt file first, so a file prompt also gets the leading space
        # and a missing --prompt does not crash when --file is given.
        if self.params.file:
            with open(self.params.file, encoding="utf-8") as f:
                self.params.prompt = f.read()
        # Add a space in front of the first character to match OG llama tokenizer behavior
        self.params.prompt = " " + (self.params.prompt or "")

        self.session_tokens: list[llama_cpp.llama_token] = []
        if len(self.params.path_session) > 0:
            print(f"attempting to load saved session from '{self.params.path_session}'", file=sys.stderr)
            if path.exists(self.params.path_session):
                _session_tokens = (llama_cpp.llama_token * self.n_ctx)()
                _n_token_count_out = llama_cpp.c_size_t()
                if llama_cpp.llama_load_session_file(
                    self.ctx,
                    self.params.path_session.encode("utf8"),
                    _session_tokens,
                    self.n_ctx,
                    ctypes.byref(_n_token_count_out),
                ) != 1:
                    self.exit()
                    raise RuntimeError(
                        f"error: failed to load session file '{self.params.path_session}'")
                n_loaded = _n_token_count_out.value
                self.session_tokens = _session_tokens[:n_loaded]
                print(f"loaded a session with prompt size of {n_loaded} tokens", file=sys.stderr)
            else:
                print("session file does not exist, will create", file=sys.stderr)

        # tokenize the prompt
        self.embd = []
        self.embd_inp = self._tokenize(self.params.prompt)
        if len(self.embd_inp) > self.n_ctx - 4:
            self.exit()
            raise RuntimeError(
                f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.n_ctx - 4})")

        # debug message about similarity of saved session, if applicable
        self.n_matching_session_tokens = 0
        if len(self.session_tokens) > 0:
            self.n_matching_session_tokens = self._session_prefix_len(
                self.embd_inp, self.session_tokens, 0)
            if self.n_matching_session_tokens >= len(self.embd_inp):
                print("session file has exact match for prompt!", file=sys.stderr)
            elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
                print(
                    f"warning: session file has low similarity to prompt "
                    f"({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); "
                    f"will mostly be reevaluated",
                    file=sys.stderr,
                )
            else:
                print(
                    f"session file matches {self.n_matching_session_tokens} / "
                    f"{len(self.embd_inp)} tokens of prompt",
                    file=sys.stderr,
                )
        self.need_to_save_session = (
            len(self.params.path_session) > 0
            and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
        )

        # number of tokens to keep when resetting context
        if self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct:
            self.params.n_keep = len(self.embd_inp)

        # in instruct mode, we inject a prefix and a suffix to each input by the user
        self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
        self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)

        self.antiecho = None
        if self.params.instruct:
            self.params.interactive_start = True
            _ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
            self.first_antiprompt.append(_ptn)
            # Filters the instruct prefix out of echoed output.
            self.antiecho = util.IterSearch(_ptn)

        # enable interactive mode if reverse prompt or interactive start is specified
        if len(self.params.antiprompt) != 0 or self.params.interactive_start:
            self.params.interactive = True

        # determine newline token
        self.llama_token_newline = self._tokenize("\n", False)
        self.llama_token_eot = self._tokenize(" [end of text]\n", False)

        if self.params.verbose_prompt:
            print(
                f"\nprompt: '{self.params.prompt}'\n"
                f"number of tokens in prompt = {len(self.embd_inp)}",
                file=sys.stderr,
            )
            for token_id in self.embd_inp:
                piece = self.token_to_str(token_id).decode("utf8", errors="ignore")
                print(f"{token_id} -> '{piece}'", file=sys.stderr)
            if self.params.n_keep > 0:
                print("static prompt based on n_keep: '", end="", file=sys.stderr)
                for i in range(self.params.n_keep):
                    piece = self.token_to_str(self.embd_inp[i]).decode("utf8", errors="ignore")
                    print(piece, end="", file=sys.stderr)
                print("'", file=sys.stderr)
            print(file=sys.stderr)

        if self.params.interactive:
            print("interactive mode on.", file=sys.stderr)
            if len(self.params.antiprompt) > 0:
                for antiprompt in self.params.antiprompt:
                    print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)
            if len(self.params.input_prefix) > 0:
                print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)

        print(
            f"sampling: repeat_last_n = {self.params.repeat_last_n},"
            f"repeat_penalty = {self.params.repeat_penalty},"
            f"presence_penalty = {self.params.presence_penalty},"
            f"frequency_penalty = {self.params.frequency_penalty},"
            f"top_k = {self.params.top_k},"
            f"tfs_z = {self.params.tfs_z},"
            f"top_p = {self.params.top_p},"
            f"typical_p = {self.params.typical_p},"
            f"temp = {self.params.temp},"
            f"mirostat = {self.params.mirostat},"
            f"mirostat_lr = {self.params.mirostat_eta},"
            f"mirostat_ent = {self.params.mirostat_tau},"
            f"generate: n_ctx = {self.n_ctx},"
            f"n_batch = {self.params.n_batch},"
            f"n_predict = {self.params.n_predict},"
            f"n_keep = {self.params.n_keep}\n",
            file=sys.stderr,
        )

        # determine antiprompt tokens
        for antiprompt in self.params.antiprompt:
            self.first_antiprompt.append(self._tokenize(antiprompt, False))

        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesnt support slices
        if self.params.interactive:
            print(
                "== Running in interactive mode. ==\n"
                "- Press Ctrl+C to interject at any time.\n"
                "- Press Return to return control to LLaMa.\n"
                "- If you want to submit another line, end your input in '\\'.\n",
                file=sys.stderr,
            )
        self.set_color(util.CONSOLE_COLOR_PROMPT)

    # tokenize a prompt
    def _tokenize(self, prompt, bos=True):
        """Tokenize ``prompt`` into a list of token ids, optionally prefixed with BOS."""
        # llama_tokenize takes a byte length, so measure the encoded text, not
        # the str (they differ for any non-ASCII character).
        text = prompt.encode("utf8", errors="ignore")
        tokens = (llama_cpp.llama_token * (len(text) + 8))()
        n = llama_cpp.llama_tokenize(self.model, text, len(text), tokens, len(tokens), bos, False)
        if n < 0:
            # Buffer too small: -n is the number of tokens required.
            tokens = (llama_cpp.llama_token * -n)()
            n = llama_cpp.llama_tokenize(self.model, text, len(text), tokens, len(tokens), bos, False)
        return tokens[:n]

    def set_color(self, c):
        """Print the console color escape ``c`` when colored output is enabled."""
        if self.params.use_color:
            print(c, end="")

    def use_antiprompt(self):
        """Return True if at least one tokenized antiprompt is configured."""
        return len(self.first_antiprompt) > 0

    @staticmethod
    def _session_prefix_len(tokens, session_tokens, start):
        """Count the leading ``tokens`` that equal ``session_tokens[start:]`` element-wise."""
        limit = min(len(tokens), len(session_tokens) - start)
        n = 0
        while n < limit and tokens[n] == session_tokens[start + n]:
            n += 1
        return n

    def _reuse_session_prefix(self):
        """Drop leading ``embd`` tokens already present in the loaded session (advancing n_past)."""
        if self.n_session_consumed >= len(self.session_tokens):
            return
        n_reused = self._session_prefix_len(self.embd, self.session_tokens, self.n_session_consumed)
        end = self.n_session_consumed + n_reused
        if n_reused < len(self.embd) and end < len(self.session_tokens):
            # Mismatch: the rest of the saved session no longer applies.
            self.session_tokens = self.session_tokens[:end]
        self.n_past += n_reused
        self.n_session_consumed = end
        self.embd = self.embd[n_reused:]

    def _eval_embd(self):
        """Evaluate ``embd`` in chunks of at most ``n_batch`` tokens, advancing n_past."""
        n_batch = max(1, self.params.n_batch)
        for start in range(0, len(self.embd), n_batch):
            # Clamp the last chunk to the tokens actually left.
            batch = self.embd[start:start + n_batch]
            if llama_cpp.llama_eval(
                self.ctx, (llama_cpp.llama_token * len(batch))(*batch), len(batch), self.n_past
            ) != 0:
                raise Exception("Failed to llama_eval!")
            self.n_past += len(batch)

    def _sample(self):
        """Sample the next token id from the logits of the last evaluated token."""
        n_vocab = llama_cpp.llama_n_vocab(self.model)
        top_k = n_vocab if self.params.top_k <= 0 else self.params.top_k
        repeat_last_n = self.n_ctx if self.params.repeat_last_n < 0 else self.params.repeat_last_n

        logits = llama_cpp.llama_get_logits(self.ctx)
        # Apply params.logit_bias map
        for token_id, bias in self.params.logit_bias.items():
            logits[token_id] += bias

        candidates_data = (llama_cpp.llama_token_data * n_vocab)(*[
            llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
            for token_id in range(n_vocab)
        ])
        candidates_p = ctypes.pointer(
            llama_cpp.llama_token_data_array(candidates_data, len(candidates_data), False))

        # Apply penalties
        nl_token = llama_cpp.llama_token_nl(self.ctx)
        nl_logit = logits[nl_token]
        last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
        last_n_data = (llama_cpp.llama_token * last_n_repeat)(
            *self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
        llama_cpp.llama_sample_repetition_penalties(
            ctx=self.ctx,
            candidates=candidates_p,
            last_tokens_data=last_n_data,
            penalty_last_n=last_n_repeat,
            penalty_repeat=llama_cpp.c_float(self.params.repeat_penalty),
            penalty_freq=llama_cpp.c_float(self.params.frequency_penalty),
            penalty_present=llama_cpp.c_float(self.params.presence_penalty),
        )
        if not self.params.penalize_nl:
            # Undo the penalty on the newline token. The candidates hold copies
            # of the logits, so the candidate entry itself must be edited.
            for candidate in candidates_data:
                if candidate.id == nl_token:
                    candidate.logit = nl_logit
                    break

        if self.params.temp <= 0:
            # Greedy sampling
            return llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)

        temp = llama_cpp.c_float(self.params.temp)
        tau = llama_cpp.c_float(self.params.mirostat_tau)
        eta = llama_cpp.c_float(self.params.mirostat_eta)
        if self.params.mirostat == 1:
            mirostat_m = 100
            llama_cpp.llama_sample_temperature(self.ctx, candidates_p, temp)
            return llama_cpp.llama_sample_token_mirostat(
                self.ctx, candidates_p, tau, eta, llama_cpp.c_int(mirostat_m), self.mirostat_mu)
        if self.params.mirostat == 2:
            llama_cpp.llama_sample_temperature(self.ctx, candidates_p, temp)
            return llama_cpp.llama_sample_token_mirostat_v2(
                self.ctx, candidates_p, tau, eta, self.mirostat_mu)

        # Temperature sampling
        min_keep = llama_cpp.c_size_t(1)
        llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k, min_keep=min_keep)
        llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z), min_keep=min_keep)
        llama_cpp.llama_sample_typical(self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p), min_keep=min_keep)
        llama_cpp.llama_sample_top_p(self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p), min_keep=min_keep)
        llama_cpp.llama_sample_temperature(self.ctx, candidates_p, temp)
        return llama_cpp.llama_sample_token(self.ctx, candidates_p)

    # generate tokens
    def generate(self):
        """Evaluate pending tokens and sample new ones, yielding token ids to display.

        Stops when the sampling budget is spent (unless interactive or
        ``n_predict == -1``), or, in interactive mode, when an antiprompt is
        seen, the initial instruct prompt has been processed, end-of-text is
        produced, or ``n_predict`` tokens were sampled.
        """
        eos = llama_cpp.llama_token_eos(self.ctx)
        while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if self.n_past + len(self.embd) > self.n_ctx:
                    n_left = self.n_past - self.params.n_keep
                    self.n_past = self.params.n_keep
                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        self.n_ctx - int(n_left / 2) - len(self.embd):-len(self.embd)
                    ]
                    self.embd = _insert + self.embd
                    # Clearing the path disables further session extension/saving.
                    self.params.path_session = ""

                # try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
                self._reuse_session_prefix()
                # evaluate tokens in batches
                self._eval_embd()

                if len(self.embd) > 0 and len(self.params.path_session) > 0:
                    self.session_tokens.extend(self.embd)
                    self.n_session_consumed = len(self.session_tokens)

            self.embd = []
            if len(self.embd_inp) <= self.input_consumed:  # && !is_interacting
                # out of user input, sample next token
                # optionally save the session on first sample (for faster prompt loading next time)
                if len(self.params.path_session) > 0 and self.need_to_save_session:
                    self.need_to_save_session = False
                    llama_cpp.llama_save_session_file(
                        self.ctx,
                        self.params.path_session.encode("utf8"),
                        (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                        len(self.session_tokens),
                    )

                token_id = self._sample()
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(token_id)

                # replace end of text token with newline token when in interactive mode
                if token_id == eos and self.params.interactive and not self.params.instruct:
                    token_id = self.llama_token_newline[0]
                    self.embd.append(token_id)
                    if self.use_antiprompt():
                        # Inject the first reverse prompt exactly once: it is
                        # echoed, recorded in last_n_tokens so the antiprompt
                        # check below stops generation, and evaluated on the next
                        # pass (it is NOT also queued in embd_inp, which would
                        # evaluate it a second time).
                        for ap_token in self.first_antiprompt[0]:
                            self.embd.append(ap_token)
                            self.last_n_tokens.pop(0)
                            self.last_n_tokens.append(ap_token)
                else:
                    # add it to the context
                    self.embd.append(token_id)

                # echo this to console
                self.output_echo = True
                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.params.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    token_id = self.embd_inp[self.input_consumed]
                    self.embd.append(token_id)
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(token_id)
                    self.input_consumed += 1
                    if len(self.embd) >= self.params.n_batch:
                        break

            # display tokens
            if self.output_echo:
                for token_id in self.embd:
                    if self.antiecho is not None:
                        yield from self.antiecho(token_id)
                    else:
                        yield token_id

            # reset color to default if there is no pending user input
            if self.params.input_echo and len(self.embd_inp) == self.input_consumed:
                self.set_color(util.CONSOLE_COLOR_DEFAULT)

            if self.params.interactive and len(self.embd_inp) <= self.input_consumed:
                # if antiprompt is present, stop
                if self.use_antiprompt() and any(
                    ap == self.last_n_tokens[-len(ap):] for ap in self.first_antiprompt
                ):
                    break

                # if we are using instruction mode, and we have processed the initial prompt
                if self.params.interactive_start:
                    break

            # end of text token: announce it outside instruct mode; either way
            # hand control back (as llama.cpp's main does)
            if len(self.embd) > 0 and self.embd[-1] == eos:
                if not self.params.instruct:
                    yield from self.llama_token_eot
                break

            # respect n_predict even if antiprompt is present
            if self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1:
                # If we aren't in instruction mode, fix the current generation by appending the antiprompt.
                # Makes it so if chat ends prematurely you don't append the AI's text etc.
                if not self.params.instruct and self.use_antiprompt():
                    self.embd_inp += self.first_antiprompt[0]
                self.remaining_tokens = self.params.n_predict
                break

        self.params.interactive_start = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.exit()

    def exit(self):
        """Free the context and model (safe to call more than once) and reset the color."""
        if getattr(self, "ctx", None):
            llama_cpp.llama_free(self.ctx)
            self.ctx = None
        if getattr(self, "model", None):
            llama_cpp.llama_free_model(self.model)
            self.model = None
        self.set_color(util.CONSOLE_COLOR_DEFAULT)

    def token_to_str(self, token_id: int) -> bytes:
        """Return the raw bytes of the text piece for ``token_id`` (may be a partial UTF-8 sequence)."""
        size = 32
        buffer = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece(self.model, llama_cpp.llama_token(token_id), buffer, size)
        if n < 0:
            # Buffer too small: -n is the required length.
            size = -n
            buffer = (ctypes.c_char * size)()
            n = llama_cpp.llama_token_to_piece(self.model, llama_cpp.llama_token(token_id), buffer, size)
        return bytes(buffer[:n])

    # return past text
    def past(self):
        """Yield the decoded pieces of the last ``n_past`` entries of last_n_tokens."""
        if self.n_past <= 0:
            # [-0:] would select the whole buffer, not nothing.
            return
        for token_id in self.last_n_tokens[-self.n_past:]:
            yield self.token_to_str(token_id).decode("utf8", errors="ignore")

    # write input
    def input(self, prompt: str):
        """Queue user text for evaluation, wrapped in the instruct prefix/suffix in instruct mode."""
        if self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix:
            self.embd_inp += self.inp_prefix
        # No BOS: this text continues the running conversation.
        self.embd_inp += self._tokenize(prompt, False)
        if self.params.instruct:
            self.embd_inp += self.inp_suffix

    # write output
    def output(self):
        """Run generate() and yield text as soon as each UTF-8 character is complete."""
        self.remaining_tokens = self.params.n_predict
        for token_id in self.generate():
            text = self._utf8_decoder.decode(self.token_to_str(token_id))
            if text:
                yield text

    # read user input
    def read_input(self):
        """Read a line from stdin; a trailing backslash continues onto the next line."""
        out = ""
        while (t := input()).endswith("\\"):
            out += t[:-1] + "\n"
        return out + t + "\n"

    # interactive mode
    def interact(self):
        """Print the initial generation, then alternate reading user input and generating."""
        for text in self.output():
            print(text, end="", flush=True)
        self.params.input_echo = False

        while self.params.interactive:
            self.set_color(util.CONSOLE_COLOR_USER_INPUT)
            if self.params.instruct:
                print('\n> ', end="")
                self.input(self.read_input())
            else:
                print(self.params.input_prefix, end="")
                self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}")
                print(self.params.input_suffix, end="")
            self.set_color(util.CONSOLE_COLOR_DEFAULT)

            # Using string instead of tokens to check for antiprompt,
            # It is more reliable than tokens for interactive mode.
            # Only this turn's text is kept so the buffer does not grow forever.
            generated_str = ""
            try:
                for text in self.output():
                    print(text, end="", flush=True)
                    generated_str += text
                    if any(generated_str.endswith(ap) for ap in self.params.antiprompt):
                        raise KeyboardInterrupt
            except KeyboardInterrupt:
                self.set_color(util.CONSOLE_COLOR_DEFAULT)
                if not self.params.instruct:
                    print(self.params.fix_prefix, end="")
                    self.input(self.params.fix_prefix)
2023-04-03 22:54:46 +02:00
def _default_chat_prompt(user_name, ai_name, now):
    """Build the few-shot chat primer used when neither --prompt nor --file is given.

    ``now`` supplies the time and year quoted in the sample dialog.
    """
    return f"""Text transcript of a never ending dialog, where {user_name} interacts with an AI assistant named {ai_name}.
{ai_name} is helpful, kind, honest, friendly, good at writing and never fails to answer {user_name}s requests immediately and with details and precision.
Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {user_name} and {ai_name} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
{user_name}: Hello, {ai_name}!
{ai_name}: Hello {user_name}! How may I help you today?
{user_name}: What time is it?
{ai_name}: It is {now.strftime("%H:%M")}.
{user_name}: What year is it?
{ai_name}: We are in {now.strftime("%Y")}.
{user_name}: What is a cat?
{ai_name}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{user_name}: Name a color.
{ai_name}: Blue
{user_name}: """


if __name__ == "__main__":
    from datetime import datetime

    USER_NAME = "User"
    AI_NAME = "ChatLLaMa"

    # Timestamp captured up front, before argument parsing.
    time_now = datetime.now()
    primer = _default_chat_prompt(USER_NAME, AI_NAME, time_now)

    params = gpt_params_parse()
    # Fall back to the built-in chat primer only when no prompt source was given.
    if params.prompt is None and params.file is None:
        params.prompt = primer

    with LLaMAInteract(params) as session:
        session.interact()