llama.cpp/examples/low_level_api/low_level_api_chat_cpp.py

566 lines
21 KiB
Python
Raw Normal View History

2023-04-03 22:54:46 +02:00
"""
This is an example implementation of main.cpp from llama.cpp
Quirks:
* It's not exactly alike since this port is designed around programmatic I/O
* Input is always echoed if on, so it should be turned off when using "input()"
* The first antiprompt should be the userprompt like "\nUser:",
because it's added when n_predict is reached (aka generation ended prematurely)
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
* Instruction mode adds its own antiprompt.
You should also still be feeding the model with a "primer" prompt that
shows it the expected format.
"""
2023-05-04 18:33:08 +02:00
import ctypes
import sys
from time import time
2023-05-04 18:33:08 +02:00
from os import cpu_count, path
2023-04-03 22:54:46 +02:00
import llama_cpp
from common import GptParams, gpt_params_parse, gpt_random_prompt
2023-05-06 15:16:58 +02:00
import util
2023-04-03 22:54:46 +02:00
# A LLaMA interactive session
class LLaMAInteract:
def __init__(self, params: GptParams) -> None:
    """Load the model and prepare all state for an interactive session.

    Validates the requested mode, loads the model (plus the optional LoRA
    adapter and saved session), tokenizes the initial prompt and sets up
    the antiprompts and counters that ``generate()`` works on.

    Args:
        params: Parsed options. The object is stored and mutated in place
            (``seed``, ``prompt``, ``n_keep``, ``interactive``,
            ``interactive_start`` and ``logit_bias`` may be rewritten).

    Raises:
        NotImplementedError: If perplexity or embedding mode is requested.
        RuntimeError: If the model, the LoRA adapter or the session file
            cannot be loaded, or if the prompt does not fit the context.

    Note:
        With ``params.mem_test`` set, a short evaluation is run, timings
        are printed, ``exit()`` is called and the constructor returns
        early without preparing any prompt state.
    """
    # input args
    self.params = params
    if self.params.perplexity:
        raise NotImplementedError("""************
please use the 'perplexity' tool for perplexity calculations
************""")
    if self.params.embedding:
        raise NotImplementedError("""************
please use the 'embedding' tool for embedding calculations
************""")

    if self.params.n_ctx > 2048:
        print(f"""warning: model does not support \
context sizes greater than 2048 tokens ({self.params.n_ctx} \
specified) expect poor results""", file=sys.stderr)

    if self.params.seed <= 0:
        self.params.seed = int(time())

    print(f"seed = {self.params.seed}", file=sys.stderr)

    if self.params.random_prompt:
        self.params.prompt = gpt_random_prompt(self.params.seed)

    # runtime args
    self.input_consumed = 0
    self.n_past = 0
    self.n_session_consumed = 0
    self.first_antiprompt = []
    self.remaining_tokens = self.params.n_predict
    self.output_echo = self.params.input_echo
    self.multibyte_fix = []

    # model load
    self.lparams = llama_cpp.llama_context_default_params()
    self.lparams.n_ctx = self.params.n_ctx
    self.lparams.n_parts = self.params.n_parts
    self.lparams.seed = self.params.seed
    self.lparams.memory_f16 = self.params.memory_f16
    self.lparams.use_mlock = self.params.use_mlock
    self.lparams.use_mmap = self.params.use_mmap

    self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
    if not self.ctx:
        raise RuntimeError(f"error: failed to load model '{self.params.model}'")

    if self.params.ignore_eos:
        self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")

    if len(self.params.lora_adapter) > 0:
        lora_status = llama_cpp.llama_apply_lora_from_file(
            self.ctx,
            self.params.lora_adapter.encode("utf8"),
            self.params.lora_base.encode("utf8") if len(self.params.lora_base) > 0 else None,
            self.params.n_threads,
        )
        if lora_status != 0:
            # Fail loudly instead of returning a half-initialised object.
            llama_cpp.llama_free(self.ctx)
            raise RuntimeError("error: failed to apply lora adapter")

    print(file=sys.stderr)
    print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)

    # determine the required inference memory per token:
    if self.params.mem_test:
        tmp = [0, 1, 2, 3]
        # The thread count lives on params; there is no self.n_threads.
        llama_cpp.llama_eval(
            self.ctx,
            (llama_cpp.llama_token * len(tmp))(*tmp),
            len(tmp),
            0,
            self.params.n_threads,
        )
        llama_cpp.llama_print_timings(self.ctx)
        self.exit()
        return

    # create internal context
    self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)

    # Load prompt file first, so the leading space below is applied to it too
    if self.params.file:
        with open(self.params.file) as f:
            self.params.prompt = f.read()

    # Add a space in front of the first character to match OG llama tokenizer behavior
    self.params.prompt = " " + self.params.prompt

    self.session_tokens: list[llama_cpp.llama_token] = []
    if len(self.params.path_session) > 0:
        print(f"attempting to load saved session from '{self.params.path_session}'", file=sys.stderr)

        if path.exists(self.params.path_session):
            session_buffer = (llama_cpp.llama_token * (self.params.n_ctx))()
            n_loaded = llama_cpp.c_size_t()

            if llama_cpp.llama_load_session_file(
                self.ctx,
                self.params.path_session.encode("utf8"),
                session_buffer,
                self.params.n_ctx,
                ctypes.byref(n_loaded),
            ) != 1:
                llama_cpp.llama_free(self.ctx)
                raise RuntimeError(f"error: failed to load session file '{self.params.path_session}'")

            n_session_tokens = n_loaded.value
            self.session_tokens = session_buffer[:n_session_tokens]
            print(f"loaded a session with prompt size of {n_session_tokens} tokens", file=sys.stderr)
        else:
            print(f"session file does not exist, will create", file=sys.stderr)

    # tokenize the prompt
    self.embd = []
    self.embd_inp = self._tokenize(self.params.prompt)

    if len(self.embd_inp) > self.n_ctx - 4:
        n_prompt_tokens = len(self.embd_inp)
        llama_cpp.llama_free(self.ctx)
        # Report the same limit that the check above uses.
        raise RuntimeError(f"error: prompt is too long ({n_prompt_tokens} tokens, max {self.n_ctx - 4})")

    # debug message about similarity of saved session, if applicable
    self.n_matching_session_tokens = 0
    if len(self.session_tokens) > 0:
        # count the common prefix of the saved session and the prompt
        for session_token, prompt_token in zip(self.session_tokens, self.embd_inp):
            if session_token != prompt_token:
                break
            self.n_matching_session_tokens += 1

        if self.n_matching_session_tokens >= len(self.embd_inp):
            print(f"session file has exact match for prompt!")
        elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
            print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
        else:
            print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")

    self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)

    # number of tokens to keep when resetting context
    if self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct:
        self.params.n_keep = len(self.embd_inp)

    self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
    self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)

    # in instruct mode, we inject a prefix and a suffix to each input by the user
    self.antiecho = None
    if self.params.instruct:
        self.params.interactive_start = True
        _ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
        self.first_antiprompt.append(_ptn)
        self.antiecho = util.IterSearch(_ptn)

    # enable interactive mode if reverse prompt or interactive start is specified
    if len(self.params.antiprompt) != 0 or self.params.interactive_start:
        self.params.interactive = True

    # determine newline token
    self.llama_token_newline = self._tokenize("\n", False)
    self.llama_token_eot = self._tokenize(" [end of text]\n", False)

    if self.params.verbose_prompt:
        print(f"""
prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

        for token in self.embd_inp:
            print(f"{token} -> '{llama_cpp.llama_token_to_str(self.ctx, token)}'", file=sys.stderr)

        if self.params.n_keep > 0:
            # stderr like the rest of this message (the header used to go to stdout)
            print("static prompt based on n_keep: '", file=sys.stderr)
            # n_keep was clamped to len(embd_inp) above, so the slice is exact
            for token in self.embd_inp[:self.params.n_keep]:
                print(llama_cpp.llama_token_to_str(self.ctx, token), file=sys.stderr)
            print("'", file=sys.stderr)
        print(file=sys.stderr)

    if self.params.interactive:
        print("interactive mode on.", file=sys.stderr)

        if len(self.params.antiprompt) > 0:
            for antiprompt in self.params.antiprompt:
                print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)

        if len(self.params.input_prefix) > 0:
            print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)

    print(f"""sampling: repeat_last_n = {self.params.repeat_last_n},\
repeat_penalty = {self.params.repeat_penalty},\
presence_penalty = {self.params.presence_penalty},\
frequency_penalty = {self.params.frequency_penalty},\
top_k = {self.params.top_k},\
tfs_z = {self.params.tfs_z},\
top_p = {self.params.top_p},\
typical_p = {self.params.typical_p},\
temp = {self.params.temp},\
mirostat = {self.params.mirostat},\
mirostat_lr = {self.params.mirostat_eta},\
mirostat_ent = {self.params.mirostat_tau},\
generate: n_ctx = {self.n_ctx},\
n_batch = {self.params.n_batch},\
n_predict = {self.params.n_predict},\
n_keep = {self.params.n_keep}
""", file=sys.stderr)

    # determine antiprompt tokens
    for antiprompt in self.params.antiprompt:
        self.first_antiprompt.append(self._tokenize(antiprompt, False))

    self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices

    if self.params.interactive:
        print("""== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to LLaMa.
- If you want to submit another line, end your input in '\\'.
""", file=sys.stderr)
    self.set_color(util.CONSOLE_COLOR_PROMPT)
2023-04-04 11:48:48 +02:00
# tokenize a prompt
def _tokenize(self, prompt, bos=True):
    """Tokenize ``prompt`` with the loaded model.

    Args:
        prompt: Text to tokenize. It is encoded as UTF-8; characters that
            cannot be encoded are dropped.
        bos: Forwarded to ``llama_tokenize`` as its add-BOS flag.

    Returns:
        The token ids as a plain list.

    Raises:
        RuntimeError: If ``llama_tokenize`` reports a failure.
    """
    # Generous upper bound: four slots per character plus four spare.
    # (presumably chosen because UTF-8 needs at most 4 bytes per character)
    capacity = (len(prompt) + 1) * 4
    token_buffer = (llama_cpp.llama_token * capacity)()
    n_tokens = llama_cpp.llama_tokenize(
        self.ctx,
        prompt.encode("utf8", errors="ignore"),
        token_buffer,
        len(token_buffer),
        bos,
    )
    if n_tokens < 0:
        # A negative count signals failure; slicing with it would silently
        # return the wrong tokens.
        raise RuntimeError(f"error: failed to tokenize prompt (llama_tokenize returned {n_tokens})")
    return token_buffer[:n_tokens]
2023-04-03 22:54:46 +02:00
def set_color(self, c):
    """Write the console colour sequence ``c`` to stdout when colour is enabled.

    Nothing is printed unless ``params.use_color`` is truthy. No trailing
    newline is added.
    """
    if not self.params.use_color:
        return
    print(c, end="")
def use_antiprompt(self):
    """Return True when at least one tokenized antiprompt is registered."""
    return bool(self.first_antiprompt)
2023-04-04 11:48:48 +02:00
# generate tokens
2023-04-03 22:54:46 +02:00
def generate(self):
    """Run the evaluate/sample loop and yield the token ids to display.

    Each pass of the loop:
      1. evaluates the tokens queued in ``self.embd`` (after an optional
         context swap and reuse of a matching saved-session prefix),
      2. either samples one new token (no queued input left) or moves up
         to ``n_batch`` queued input tokens from ``embd_inp`` to ``embd``,
      3. yields the tokens in ``embd`` when echo is on,
      4. checks the stop conditions (antiprompt seen, interactive start,
         end-of-text token, ``n_predict`` budget used up).

    Yields:
        Token ids (and, in instruct mode, whatever ``self.antiecho``
        yields for them).

    Raises:
        Exception: If ``llama_eval`` reports an error.
    """
    while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
        # predict
        if len(self.embd) > 0:
            # infinite text generation via context swapping
            # if we run out of context:
            # - take the n_keep first tokens from the original prompt (via n_past)
            # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
            if self.n_past + len(self.embd) > self.n_ctx:
                n_left = self.n_past - self.params.n_keep
                self.n_past = self.params.n_keep

                # insert n_left/2 tokens at the start of embd from last_n_tokens
                _insert = self.last_n_tokens[
                    self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
                ]
                self.embd = _insert + self.embd
                # an empty session path disables the session bookkeeping below
                self.params.path_session = ""

            # try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
            if self.n_session_consumed < len(self.session_tokens):
                # Count reused tokens explicitly: the loop index is one short
                # of the count when the loop finishes without a break, which
                # used to leave an already-counted token in embd.
                n_reused = 0
                for token in self.embd:
                    if token != self.session_tokens[self.n_session_consumed]:
                        # the session diverges here; drop its unmatched tail
                        self.session_tokens = self.session_tokens[:self.n_session_consumed]
                        break

                    self.n_past += 1
                    self.n_session_consumed += 1
                    n_reused += 1

                    if self.n_session_consumed >= len(self.session_tokens):
                        break

                if n_reused > 0:
                    self.embd = self.embd[n_reused:]

            # evaluate tokens
            # TODO BUG: evaluating in n_batch-sized chunks caused nonsensical
            # generation, so everything left in embd is evaluated in one call.
            # Nothing is evaluated when the session covered every pending token.
            if len(self.embd) > 0:
                if llama_cpp.llama_eval(
                    self.ctx,
                    (llama_cpp.llama_token * len(self.embd))(*self.embd),
                    len(self.embd),
                    self.n_past,
                    self.params.n_threads,
                ) != 0:
                    raise Exception("Failed to llama_eval!")

                if len(self.params.path_session) > 0:
                    self.session_tokens.extend(self.embd)
                    self.n_session_consumed = len(self.session_tokens)

        self.n_past += len(self.embd)
        self.embd = []

        if len(self.embd_inp) <= self.input_consumed: #&& !is_interacting
            # out of user input, sample next token
            top_k = llama_cpp.llama_n_vocab(self.ctx) if self.params.top_k <= 0 else self.params.top_k
            repeat_last_n = self.n_ctx if self.params.repeat_last_n < 0 else self.params.repeat_last_n

            # optionally save the session on first sample (for faster prompt loading next time)
            if len(self.params.path_session) > 0 and self.need_to_save_session:
                self.need_to_save_session = False
                llama_cpp.llama_save_session_file(
                    self.ctx,
                    self.params.path_session.encode("utf8"),
                    (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                    len(self.session_tokens),
                )

            logits = llama_cpp.llama_get_logits(self.ctx)
            n_vocab = llama_cpp.llama_n_vocab(self.ctx)

            # Apply params.logit_bias map
            for key, value in self.params.logit_bias.items():
                logits[key] += value

            _arr = (llama_cpp.llama_token_data * n_vocab)(*[
                llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                for token_id in range(n_vocab)
            ])
            candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))

            # Apply penalties
            nl_logit = logits[llama_cpp.llama_token_nl()]
            last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

            _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
            llama_cpp.llama_sample_repetition_penalty(
                self.ctx, candidates_p, _arr,
                last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
            llama_cpp.llama_sample_frequency_and_presence_penalties(
                self.ctx, candidates_p, _arr,
                last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))

            if not self.params.penalize_nl:
                logits[llama_cpp.llama_token_nl()] = nl_logit

            if self.params.temp <= 0:
                # Greedy sampling
                sampled = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
            elif self.params.mirostat in (1, 2):
                # mu is the running state of Mirostat, so one value is kept for
                # the whole session instead of restarting from 2*tau per token.
                # (assumes the binding passes mu by pointer, as the C API declares)
                mirostat_mu = getattr(self, "_mirostat_mu", None)
                if mirostat_mu is None:
                    mirostat_mu = self._mirostat_mu = llama_cpp.c_float(2.0 * self.params.mirostat_tau)

                llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
                if self.params.mirostat == 1:
                    mirostat_m = 100
                    sampled = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_int(mirostat_m), mirostat_mu)
                else:
                    sampled = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), mirostat_mu)
            else:
                # Temperature sampling
                llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k)
                llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z))
                llama_cpp.llama_sample_typical(self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p))
                llama_cpp.llama_sample_top_p(self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p))
                llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
                sampled = llama_cpp.llama_sample_token(self.ctx, candidates_p)

            self.last_n_tokens.pop(0)
            self.last_n_tokens.append(sampled)

            # replace end of text token with newline token when in interactive mode
            if sampled == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct:
                sampled = self.llama_token_newline[0]
                if self.use_antiprompt():
                    # tokenize and inject first reverse prompt
                    self.embd_inp += self.first_antiprompt[0]

            # add it to the context
            self.embd.append(sampled)

            # echo this to console
            self.output_echo = True

            # decrement remaining sampling budget
            self.remaining_tokens -= 1
        else:
            # output to console if input echo is on
            self.output_echo = self.params.input_echo

            # some user input remains from prompt or interaction, forward it to processing
            while len(self.embd_inp) > self.input_consumed:
                pending = self.embd_inp[self.input_consumed]
                self.embd.append(pending)
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(pending)
                self.input_consumed += 1
                if len(self.embd) >= self.params.n_batch:
                    break

        # display tokens
        if self.output_echo:
            for token in self.embd:
                if self.antiecho is not None:
                    yield from self.antiecho(token)
                else:
                    yield token

        # reset color to default if there is no pending user input
        if self.params.input_echo and len(self.embd_inp) == self.input_consumed:
            self.set_color(util.CONSOLE_COLOR_DEFAULT)

        if self.params.interactive and len(self.embd_inp) <= self.input_consumed:
            # if antiprompt is present, stop
            if self.use_antiprompt():
                if any(
                    antiprompt == self.last_n_tokens[-len(antiprompt):]
                    for antiprompt in self.first_antiprompt
                ):
                    break

            # if we are using instruction mode, and we have processed the initial prompt
            if self.params.interactive_start:
                break

        # end of text token
        if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
            if not self.params.instruct:
                yield from self.llama_token_eot
            break

        # respect n_predict even if antiprompt is present
        if self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1:
            # If we arent in instruction mode, fix the current generation by appending the antiprompt.
            # Makes it so if chat ends prematurely you dont append the AI's text etc.
            # Guarded: interactive mode can be active without any antiprompt.
            if not self.params.instruct and self.use_antiprompt():
                self.embd_inp += self.first_antiprompt[0]
            # refill the budget that the loop condition checks
            self.remaining_tokens = self.params.n_predict
            break

    self.params.interactive_start = False
def __enter__(self):
return self
def __exit__(self, type, value, tb):
    """Leave the ``with`` block by calling ``exit()``.

    The exception details are ignored and nothing is returned, so an
    exception raised inside the block keeps propagating.
    """
    self.exit()
def exit(self):
    """Free the llama context and reset the console colour.

    Safe to call more than once: the context handle is cleared after it
    is freed. The mem_test path in ``__init__`` calls ``exit()`` itself
    and ``__exit__`` calls it again, which used to free the same context
    twice.
    """
    ctx = getattr(self, "ctx", None)
    if ctx:
        llama_cpp.llama_free(ctx)
        self.ctx = None
    self.set_color(util.CONSOLE_COLOR_DEFAULT)
2023-04-04 11:48:48 +02:00
# return past text
2023-04-03 22:54:46 +02:00
def past(self):
    """Yield the text of the last ``n_past`` entries of ``last_n_tokens``.

    Yields:
        One decoded string per token; bytes that are not valid UTF-8 are
        dropped.
    """
    # Without this guard, [-0:] would select the whole buffer, not nothing.
    if self.n_past <= 0:
        return
    for token in self.last_n_tokens[-self.n_past:]:
        yield llama_cpp.llama_token_to_str(self.ctx, token).decode("utf8", errors="ignore")
2023-04-03 22:54:46 +02:00
2023-04-04 11:48:48 +02:00
# write input
2023-04-03 22:54:46 +02:00
def input(self, prompt: str):
    """Queue a piece of user text for evaluation by ``generate()``.

    The tokenized text is appended to ``self.embd_inp``. In instruct mode
    it is wrapped in the instruction prefix and suffix; the prefix is
    skipped when the most recent tokens already end with it.

    Args:
        prompt: The text to append.
    """
    if self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix:
        self.embd_inp += self.inp_prefix
    # User text continues an existing context, so no BOS token is added
    # (the reference main.cpp tokenizes user input without BOS as well).
    self.embd_inp += self._tokenize(prompt, False)
    if self.params.instruct:
        self.embd_inp += self.inp_suffix
2023-04-03 22:54:46 +02:00
2023-04-04 11:48:48 +02:00
# write output
2023-04-03 22:54:46 +02:00
def output(self):
    """Generate tokens and yield them as decoded text.

    Resets the sampling budget to ``params.n_predict`` and drives
    ``generate()``. A token whose bytes look like the start of a
    multi-byte UTF-8 sequence is held back in ``self.multibyte_fix``
    until the following tokens have filled the sequence.

    Yields:
        Decoded text pieces; bytes that are not valid UTF-8 are dropped,
        as in ``past()``.
    """
    self.remaining_tokens = self.params.n_predict
    for token in self.generate():
        cur_char = llama_cpp.llama_token_to_str(self.ctx, token)

        # Add remainder of missing bytes
        if None in self.multibyte_fix:
            self.multibyte_fix[self.multibyte_fix.index(None)] = cur_char

        # Return completed utf char
        if len(self.multibyte_fix) > 0 and None not in self.multibyte_fix:
            yield b"".join(self.multibyte_fix).decode("utf8", errors="ignore")
            self.multibyte_fix = []
            continue

        # Contains multi-byte UTF8
        # An explicit byte order keeps this working before Python 3.11,
        # where int.from_bytes has no default for it.
        char_bits = int.from_bytes(cur_char, "big")
        for num, pattern in [(2, 192), (3, 224), (4, 240)]:
            # Bitwise AND check; a later (longer) match overrides an earlier one
            if pattern & char_bits == pattern:
                self.multibyte_fix = [cur_char] + ([None] * (num-1))

        # Stop incomplete bytes from passing
        if len(self.multibyte_fix) > 0:
            continue

        yield cur_char.decode("utf8", errors="ignore")
2023-04-03 22:54:46 +02:00
# read user input
def read_input(self):
    """Read one logical line of user input from stdin.

    A line ending in a backslash is continued on the next line: the
    backslash is removed and reading goes on. The joined lines are
    returned separated by newlines, with a trailing newline.
    """
    pieces = []
    while True:
        line = input()
        if not line.endswith("\\"):
            break
        pieces.append(line[:-1])
    pieces.append(line)
    return "\n".join(pieces) + "\n"
# interactive mode
def interact(self):
    """Run the console chat loop.

    Prints the output for the initial prompt, turns input echo off, then
    alternates between reading user input and printing the model's reply
    for as long as ``params.interactive`` is set. Ctrl+C while a reply is
    being printed stops that reply; outside instruct mode
    ``params.fix_prefix`` is then printed and queued as input.
    """
    def stream_reply():
        # print each generated piece as soon as it is available
        for piece in self.output():
            print(piece, end="", flush=True)

    stream_reply()
    self.params.input_echo = False

    while self.params.interactive:
        self.set_color(util.CONSOLE_COLOR_USER_INPUT)
        if not self.params.instruct:
            prefix = self.params.input_prefix
            print(prefix, end="")
            typed = self.read_input()
            self.input(f"{prefix}{typed}{self.params.input_suffix}")
            print(self.params.input_suffix, end="")
        else:
            print('\n> ', end="")
            self.input(self.read_input())
        self.set_color(util.CONSOLE_COLOR_DEFAULT)

        try:
            stream_reply()
        except KeyboardInterrupt:
            self.set_color(util.CONSOLE_COLOR_DEFAULT)
            if self.params.instruct:
                continue
            print(self.params.fix_prefix, end="")
            self.input(self.params.fix_prefix)
2023-04-03 22:54:46 +02:00
if __name__ == "__main__":
    # Script entry point: chat with the model described by the command-line
    # options, starting from a demo transcript when no prompt is supplied.
    from datetime import datetime

    USER_NAME = "User"
    AI_NAME = "ChatLLaMa"

    time_now = datetime.now()
    # Demo "primer" transcript that shows the model the expected chat format.
    prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What time is it?
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
{USER_NAME}: What year is it?
{AI_NAME}: We are in {time_now.strftime("%Y")}.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""

    params = gpt_params_parse()
    # The transcript above used to be built and then never used. Fall back
    # to it only when neither a prompt nor a prompt file was supplied.
    if not params.prompt and not params.file:
        params.prompt = prompt

    with LLaMAInteract(params) as m:
        m.interact()