"""
This is an example implementation of main.cpp from llama.cpp
Quirks:
 * It's not exactly alike since this port is designed around programmatic I/O
 * Input is always echoed if on, so it should be turned off when using "input()"
 * The first antiprompt should be the userprompt like "\\nUser:",
   because it's added when n_predict is reached (aka generation ended prematurely)
 * n_predict can be set to -1 for unlimited length responses (or just a really high value)
 * Instruction mode adds its own antiprompt.
   You should also still be feeding the model with a "primer" prompt that
   shows it the expected format.
"""
import ctypes
import sys
from time import time
from os import cpu_count, path

import llama_cpp
from common import GptParams, gpt_params_parse, gpt_random_prompt
import util


# A LLaMA interactive session
class LLaMAInteract:
    """Interactive text-generation session on top of the low-level llama_cpp API.

    The constructor loads the model, optionally applies a LoRA adapter and a
    saved session, and tokenizes the initial prompt. Text is then produced by
    alternating ``input()`` (queue user text) and ``output()`` (stream decoded
    text), or by calling ``interact()`` for a console loop. The class is a
    context manager; leaving the ``with`` block frees the llama context.
    """

    def __init__(self, params: GptParams) -> None:
        """Load the model and prepare all generation state from ``params``.

        Raises:
            NotImplementedError: if perplexity or embedding mode is requested.
            RuntimeError: if the model cannot be loaded or the prompt is too long.

        Note: on a failed LoRA/session load, and after ``mem_test``, the
        constructor returns early and the instance is only partially set up.
        """
        # input args
        self.params = params
        if self.params.perplexity:
            raise NotImplementedError("""************
please use the 'perplexity' tool for perplexity calculations
************""")
        if self.params.embedding:
            raise NotImplementedError("""************
please use the 'embedding' tool for embedding calculations
************""")

        if self.params.n_ctx > 2048:
            print(
                "warning: model does not support "
                f"context sizes greater than 2048 tokens ({self.params.n_ctx} "
                "specified) expect poor results",
                file=sys.stderr,
            )

        if self.params.seed <= 0:
            self.params.seed = int(time())

        print(f"seed = {self.params.seed}", file=sys.stderr)

        if self.params.random_prompt:
            self.params.prompt = gpt_random_prompt(self.params.seed)

        # runtime args
        self.input_consumed = 0
        self.n_past = 0
        self.n_session_consumed = 0
        self.first_antiprompt = []
        self.remaining_tokens = self.params.n_predict
        self.output_echo = self.params.input_echo
        self.multibyte_fix = []

        # model load
        self.lparams = llama_cpp.llama_context_default_params()
        self.lparams.n_ctx = self.params.n_ctx
        self.lparams.n_parts = self.params.n_parts
        self.lparams.seed = self.params.seed
        self.lparams.memory_f16 = self.params.memory_f16
        self.lparams.use_mlock = self.params.use_mlock
        self.lparams.use_mmap = self.params.use_mmap

        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
        if not self.ctx:
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        if self.params.ignore_eos:
            self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")

        if len(self.params.lora_adapter) > 0:
            # The C API takes `const char *`, so str paths must be encoded first.
            lora_adapter = self.params.lora_adapter
            if isinstance(lora_adapter, str):
                lora_adapter = lora_adapter.encode("utf8")
            lora_base = self.params.lora_base if len(self.params.lora_base) > 0 else None
            if isinstance(lora_base, str):
                lora_base = lora_base.encode("utf8")

            if llama_cpp.llama_apply_lora_from_file(
                self.ctx,
                lora_adapter,
                lora_base,
                self.params.n_threads,
            ) != 0:
                print("error: failed to apply lora adapter", file=sys.stderr)
                return

        print(file=sys.stderr)
        print(
            f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} "
            f"| {llama_cpp.llama_print_system_info().decode('utf8')}",
            file=sys.stderr,
        )

        # determine the required inference memory per token:
        if self.params.mem_test:
            tmp = [0, 1, 2, 3]
            llama_cpp.llama_eval(
                self.ctx,
                (llama_cpp.c_int * len(tmp))(*tmp),
                len(tmp),
                0,
                self.params.n_threads,
            )
            llama_cpp.llama_print_timings(self.ctx)
            self.exit()
            return

        # create internal context
        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)

        # Load prompt file
        if self.params.file:
            with open(self.params.file) as f:
                self.params.prompt = f.read()

        # Add a space in front of the first character to match OG llama tokenizer behavior.
        # Done after the prompt file is read so a file-provided prompt gets it too.
        self.params.prompt = " " + self.params.prompt

        self.session_tokens: list[llama_cpp.llama_token] = []
        if len(self.params.path_session) > 0:
            print(f"attempting to load saved session from '{self.params.path_session}'", file=sys.stderr)

            if path.exists(self.params.path_session):
                _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
                # llama.h declares the out-parameter as `size_t *`.
                _n_token_count_out = ctypes.c_size_t()
                # llama_load_session_file returns a C bool: true on success.
                if not llama_cpp.llama_load_session_file(
                    self.ctx,
                    self.params.path_session.encode("utf8"),
                    _session_tokens,
                    self.params.n_ctx,
                    ctypes.byref(_n_token_count_out),
                ):
                    print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
                    return
                n_loaded = _n_token_count_out.value
                self.session_tokens = _session_tokens[:n_loaded]
                print(f"loaded a session with prompt size of {n_loaded} tokens", file=sys.stderr)
            else:
                print(f"session file does not exist, will create", file=sys.stderr)

        # tokenize the prompt
        self.embd = []
        self.embd_inp = self._tokenize(self.params.prompt)

        if len(self.embd_inp) > self.n_ctx - 4:
            raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.n_ctx - 4})")

        # debug message about similarity of saved session, if applicable
        n_matching_session_tokens = 0
        if len(self.session_tokens) > 0:
            for session_token in self.session_tokens:
                if (n_matching_session_tokens >= len(self.embd_inp)
                        or session_token != self.embd_inp[n_matching_session_tokens]):
                    break
                n_matching_session_tokens += 1

            if n_matching_session_tokens >= len(self.embd_inp):
                print(f"session file has exact match for prompt!", file=sys.stderr)
            elif n_matching_session_tokens < (len(self.embd_inp) / 2):
                print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated", file=sys.stderr)
            else:
                print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt", file=sys.stderr)

        # number of tokens to keep when resetting context
        if self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct:
            self.params.n_keep = len(self.embd_inp)

        self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
        self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)

        # in instruct mode, we inject a prefix and a suffix to each input by the user
        self.antiecho = None
        if self.params.instruct:
            self.params.interactive_start = True
            _ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
            self.first_antiprompt.append(_ptn)
            self.antiecho = util.IterSearch(_ptn)

        # enable interactive mode if reverse prompt or interactive start is specified
        if len(self.params.antiprompt) != 0 or self.params.interactive_start:
            self.params.interactive = True

        # determine newline token
        self.llama_token_newline = self._tokenize("\n", False)
        self.llama_token_eot = self._tokenize(" [end of text]\n", False)

        if self.params.verbose_prompt:
            print(f"""
prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

            for token in self.embd_inp:
                print(f"{token} -> '{llama_cpp.llama_token_to_str(self.ctx, token)}'", file=sys.stderr)

            if self.params.n_keep > 0:
                print("static prompt based on n_keep: '", file=sys.stderr)
                for i in range(self.params.n_keep):
                    print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
                print("'", file=sys.stderr)
            print(file=sys.stderr)

        if self.params.interactive:
            print("interactive mode on.", file=sys.stderr)

            if len(self.params.antiprompt) > 0:
                for antiprompt in self.params.antiprompt:
                    print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)

            if len(self.params.input_prefix) > 0:
                print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)

        print(
            f"sampling: repeat_last_n = {self.params.repeat_last_n},"
            f" repeat_penalty = {self.params.repeat_penalty},"
            f" presence_penalty = {self.params.presence_penalty},"
            f" frequency_penalty = {self.params.frequency_penalty},"
            f" top_k = {self.params.top_k},"
            f" tfs_z = {self.params.tfs_z},"
            f" top_p = {self.params.top_p},"
            f" typical_p = {self.params.typical_p},"
            f" temp = {self.params.temp},"
            f" mirostat = {self.params.mirostat},"
            f" mirostat_lr = {self.params.mirostat_eta},"
            f" mirostat_ent = {self.params.mirostat_tau},"
            "\n"
            f"generate: n_ctx = {self.n_ctx},"
            f" n_batch = {self.params.n_batch},"
            f" n_predict = {self.params.n_predict},"
            f" n_keep = {self.params.n_keep}"
            "\n",
            file=sys.stderr,
        )

        # determine antiprompt tokens
        for antiprompt in self.params.antiprompt:
            self.first_antiprompt.append(self._tokenize(antiprompt, False))

        # Sliding window of the most recent tokens, always exactly n_ctx long.
        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesnt support slices

        if self.params.interactive:
            print("""== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to LLaMa.
 - If you want to submit another line, end your input in '\\'.

""", file=sys.stderr)
            self.set_color(util.CONSOLE_COLOR_PROMPT)

        # Save the session on first sample unless the loaded one already covers most of the prompt.
        self.need_to_save_session = (
            len(self.params.path_session) > 0
            and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
        )

    # tokenize a prompt
    def _tokenize(self, prompt, bos=True):
        """Return the token ids for ``prompt`` as a list, optionally prefixed with BOS.

        Raises:
            RuntimeError: if ``llama_tokenize`` reports a negative token count.
        """
        # Generously sized output buffer: four slots per character, plus slack.
        _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
        if _n < 0:
            # A negative count would otherwise be used as a slice bound and return garbage.
            raise RuntimeError(f"error: failed to tokenize prompt (llama_tokenize returned {_n})")
        return _arr[:_n]

    def set_color(self, c):
        """Print the console colour code ``c`` when colour output is enabled."""
        if self.params.use_color:
            print(c, end="")

    def use_antiprompt(self):
        """Return True when at least one antiprompt token sequence is registered."""
        return len(self.first_antiprompt) > 0

    # generate tokens
    def generate(self):
        """Run the eval/sample loop, yielding token ids that should be echoed.

        Stops on an antiprompt match (interactive mode), on the end-of-text
        token, or when the ``n_predict`` budget is exhausted.
        """
        while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if self.n_past + len(self.embd) > self.n_ctx:
                    n_left = self.n_past - self.params.n_keep
                    self.n_past = self.params.n_keep

                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        self.n_ctx - int(n_left / 2) - len(self.embd):-len(self.embd)
                    ]
                    self.embd = _insert + self.embd
                    # stop saving the session once the context has been swapped
                    self.params.path_session = ""

                # try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
                if self.n_session_consumed < len(self.session_tokens):
                    # Count reused tokens explicitly; a Python `for` variable ends at
                    # len-1 (not len) after a full pass, which would leave one token behind.
                    n_reused = 0
                    for token in self.embd:
                        if token != self.session_tokens[self.n_session_consumed]:
                            # session diverges from the prompt here: drop the stale tail
                            self.session_tokens = self.session_tokens[:self.n_session_consumed]
                            break

                        self.n_past += 1
                        self.n_session_consumed += 1
                        n_reused += 1

                        if self.n_session_consumed >= len(self.session_tokens):
                            break

                    if n_reused > 0:
                        self.embd = self.embd[n_reused:]

                # evaluate tokens
                # embd is typically prepared beforehand to fit within a batch, but not always
                # TODO: evaluate in n_batch-sized chunks. The earlier attempt produced
                # nonsensical output; it always passed n_batch as the token count, even
                # for a shorter final chunk, which is the likely cause (unconfirmed).
                if len(self.embd) > 0:
                    if llama_cpp.llama_eval(
                        self.ctx,
                        (llama_cpp.llama_token * len(self.embd))(*self.embd),
                        len(self.embd),
                        self.n_past,
                        self.params.n_threads,
                    ) != 0:
                        raise Exception("Failed to llama_eval!")

                # Record evaluated tokens only while a session file is in use.
                if len(self.embd) > 0 and len(self.params.path_session) > 0:
                    self.session_tokens.extend(self.embd)
                    self.n_session_consumed = len(self.session_tokens)

            self.n_past += len(self.embd)
            self.embd = []
            if len(self.embd_inp) <= self.input_consumed:  # && !is_interacting
                # out of user input, sample next token
                top_k = llama_cpp.llama_n_vocab(self.ctx) if self.params.top_k <= 0 else self.params.top_k
                repeat_last_n = self.n_ctx if self.params.repeat_last_n < 0 else self.params.repeat_last_n

                # optionally save the session on first sample (for faster prompt loading next time)
                if len(self.params.path_session) > 0 and self.need_to_save_session:
                    self.need_to_save_session = False
                    # The C API needs a contiguous llama_token array, not a Python list.
                    session_arr = (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens)
                    llama_cpp.llama_save_session_file(
                        self.ctx,
                        self.params.path_session.encode("utf8"),
                        session_arr,
                        len(self.session_tokens),
                    )

                id = 0

                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.ctx)

                # Apply params.logit_bias map
                for key, value in self.params.logit_bias.items():
                    # Indexing the logits pointer yields a Python float, so add a plain number.
                    logits[key] += value

                candidates = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = ctypes.pointer(
                    llama_cpp.llama_token_data_array(candidates, len(candidates), False)
                )

                # Apply penalties
                nl_token = llama_cpp.llama_token_nl()
                nl_logit = logits[nl_token]
                last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

                recent = (llama_cpp.llama_token * last_n_repeat)(
                    *self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:]
                )
                llama_cpp.llama_sample_repetition_penalty(
                    self.ctx, candidates_p, recent, last_n_repeat,
                    llama_cpp.c_float(self.params.repeat_penalty))
                llama_cpp.llama_sample_frequency_and_presence_penalties(
                    self.ctx, candidates_p, recent, last_n_repeat,
                    llama_cpp.c_float(self.params.frequency_penalty),
                    llama_cpp.c_float(self.params.presence_penalty))

                if not self.params.penalize_nl:
                    # Restore the newline logit on the candidate list; the penalties were
                    # applied to the candidates, so restoring the raw logits has no effect.
                    for idx in range(n_vocab):
                        if candidates[idx].id == nl_token:
                            candidates[idx].logit = nl_logit
                            break

                if self.params.temp <= 0:
                    # Greedy sampling
                    id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
                else:
                    if self.params.mirostat == 1:
                        mirostat_mu = 2.0 * self.params.mirostat_tau
                        mirostat_m = 100
                        llama_cpp.llama_sample_temperature(
                            self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
                        id = llama_cpp.llama_sample_token_mirostat(
                            self.ctx, candidates_p,
                            llama_cpp.c_float(self.params.mirostat_tau),
                            llama_cpp.c_float(self.params.mirostat_eta),
                            llama_cpp.c_int(mirostat_m),
                            llama_cpp.c_float(mirostat_mu))
                    elif self.params.mirostat == 2:
                        mirostat_mu = 2.0 * self.params.mirostat_tau
                        llama_cpp.llama_sample_temperature(
                            self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
                        id = llama_cpp.llama_sample_token_mirostat_v2(
                            self.ctx, candidates_p,
                            llama_cpp.c_float(self.params.mirostat_tau),
                            llama_cpp.c_float(self.params.mirostat_eta),
                            llama_cpp.c_float(mirostat_mu))
                    else:
                        # Temperature sampling
                        llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k)
                        llama_cpp.llama_sample_tail_free(
                            self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z))
                        llama_cpp.llama_sample_typical(
                            self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p))
                        llama_cpp.llama_sample_top_p(
                            self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p))
                        llama_cpp.llama_sample_temperature(
                            self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
                        id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
                # print("`{}`".format(candidates_p.size))

                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(id)

                # replace end of text token with newline token when in interactive mode
                if id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct:
                    id = self.llama_token_newline[0]
                    if self.use_antiprompt():
                        # tokenize and inject first reverse prompt
                        self.embd_inp += self.first_antiprompt[0]

                # add it to the context
                self.embd.append(id)

                # echo this to console
                self.output_echo = True

                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.params.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    self.embd.append(self.embd_inp[self.input_consumed])
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                    self.input_consumed += 1
                    if len(self.embd) >= self.params.n_batch:
                        break

            # display tokens
            if self.output_echo:
                for id in self.embd:
                    if self.antiecho is not None:
                        # the antiecho filter decides which tokens are released for display
                        for r in self.antiecho(id):
                            yield r
                    else:
                        yield id

            # reset color to default if there is no pending user input
            if self.params.input_echo and len(self.embd_inp) == self.input_consumed:
                self.set_color(util.CONSOLE_COLOR_DEFAULT)

            if self.params.interactive and len(self.embd_inp) <= self.input_consumed:
                # if antiprompt is present, stop
                if self.use_antiprompt():
                    if any(
                        antiprompt == self.last_n_tokens[-len(antiprompt):]
                        for antiprompt in self.first_antiprompt
                    ):
                        break

                # if we are using instruction mode, and we have processed the initial prompt
                if self.params.interactive_start:
                    break

            # end of text token
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                if not self.params.instruct:
                    for i in self.llama_token_eot:
                        yield i
                break

            # respect n_predict even if antiprompt is present
            if self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1:
                # If we arent in instruction mode, fix the current generation by appending the antiprompt.
                # Makes it so if chat ends prematurely you dont append the AI's text etc.
                if not self.params.instruct and self.use_antiprompt():
                    self.embd_inp += self.first_antiprompt[0]
                # refill the sampling budget for the next turn
                self.remaining_tokens = self.params.n_predict
                break

        self.params.interactive_start = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.exit()

    def exit(self):
        """Free the llama context (at most once) and reset the console colour."""
        # Guarded so a second call (e.g. mem_test followed by __exit__) does not free twice.
        if self.ctx:
            llama_cpp.llama_free(self.ctx)
            self.ctx = None
        self.set_color(util.CONSOLE_COLOR_DEFAULT)

    # return past text
    def past(self):
        """Yield the decoded text of the last ``n_past`` tokens in the token window."""
        if self.n_past <= 0:
            # `[-0:]` would select the whole window instead of nothing.
            return
        for id in self.last_n_tokens[-self.n_past:]:
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")

    # write input
    def input(self, prompt: str):
        """Tokenize ``prompt`` and queue it; instruct mode wraps it in prefix and suffix."""
        if self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix:
            self.embd_inp += self.inp_prefix
        self.embd_inp += self._tokenize(prompt)
        if self.params.instruct:
            self.embd_inp += self.inp_suffix

    # write output
    def output(self):
        """Generate with a fresh ``n_predict`` budget, yielding decoded text pieces.

        UTF-8 characters split across several tokens are buffered in
        ``self.multibyte_fix`` and yielded only once every byte has arrived.
        """
        self.remaining_tokens = self.params.n_predict
        for id in self.generate():
            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)

            # Add remainder of missing bytes
            if None in self.multibyte_fix:
                self.multibyte_fix[self.multibyte_fix.index(None)] = cur_char

            # Return completed utf char
            if len(self.multibyte_fix) > 0 and None not in self.multibyte_fix:
                yield (b"".join(self.multibyte_fix)).decode("utf8", errors="ignore")
                self.multibyte_fix = []
                continue

            # Contains multi-byte UTF8
            for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                # Bitwise AND check; byteorder is explicit because it is mandatory before Python 3.11
                if pattern & int.from_bytes(cur_char, "big") == pattern:
                    self.multibyte_fix = [cur_char] + ([None] * (num - 1))

            # Stop incomplete bytes from passing
            if len(self.multibyte_fix) > 0:
                continue

            yield cur_char.decode("utf8", errors="ignore")

    # read user input
    def read_input(self):
        """Read console input; a line ending in a backslash continues on the next line."""
        out = ""
        while (t := input()).endswith("\\"):
            out += t[:-1] + "\n"
        return out + t + "\n"

    # interactive mode
    def interact(self):
        """Print the initial generation, then alternate user input and model output."""
        for i in self.output():
            print(i, end="", flush=True)
        self.params.input_echo = False

        while self.params.interactive:
            self.set_color(util.CONSOLE_COLOR_USER_INPUT)
            if self.params.instruct:
                print('\n> ', end="")
                self.input(self.read_input())
            else:
                print(self.params.input_prefix, end="")
                self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}")
                print(self.params.input_suffix, end="")
            self.set_color(util.CONSOLE_COLOR_DEFAULT)

            try:
                for i in self.output():
                    print(i, end="", flush=True)
            except KeyboardInterrupt:
                # Ctrl+C interrupts generation and hands control back to the user.
                self.set_color(util.CONSOLE_COLOR_DEFAULT)
                if not self.params.instruct:
                    print(self.params.fix_prefix, end="")
                    self.input(self.params.fix_prefix)


if __name__ == "__main__":
    from datetime import datetime

    USER_NAME="User"
    AI_NAME="ChatLLaMa"

    time_now = datetime.now()
    # NOTE(review): `prompt` is built here but never handed to `params`; presumably it
    # was meant as the default primer prompt - confirm before wiring it in.
    prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.

{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What time is it?
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
{USER_NAME}: What year is it?
{AI_NAME}: We are in {time_now.strftime("%Y")}.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""
    params = gpt_params_parse()

    with LLaMAInteract(params) as m:
        m.interact()