2023-04-03 22:54:46 +02:00
"""
This is an example implementation of main.cpp from llama.cpp
Quirks:
 * It's not exactly alike, since this port is designed around programmatic I/O
 * Input is always echoed if on, so it should be turned off when using "input()"
 * The first antiprompt should be the user prompt, like "\nUser:",
   because it's added when n_predict is reached (aka generation ended prematurely)
 * n_predict can be set to -1 for unlimited length responses (or just a really high value)
 * Instruction mode adds its own antiprompt.
   You should also still be feeding the model with a "primer" prompt that
   shows it the expected format.
2023-04-03 22:54:46 +02:00
"""
2023-05-04 18:33:08 +02:00
import ctypes
2023-04-06 15:30:57 +02:00
import sys
from time import time
2023-05-04 18:33:08 +02:00
from os import cpu_count , path
2023-04-06 15:30:57 +02:00
2023-04-03 22:54:46 +02:00
import llama_cpp
2023-04-06 15:30:57 +02:00
from common import GptParams , gpt_params_parse , gpt_random_prompt
2023-05-06 15:16:58 +02:00
import util
2023-04-10 16:35:38 +02:00
2023-04-03 22:54:46 +02:00
# A LLaMA interactive session
class LLaMAInteract:
    """Interactive text-generation session built on the low-level ``llama_cpp`` API.

    A port of llama.cpp's ``main`` example driven programmatically: push text
    with :meth:`input`, pull decoded text with :meth:`output` (or raw token ids
    with :meth:`generate`), or run a console chat with :meth:`interact`. Use it
    as a context manager so the native context and model are released.
    """

    def __init__(self, params: GptParams) -> None:
        """Load the model, tokenize the prompt and prepare the sampling state.

        Args:
            params: parsed options (see ``common.GptParams``). Some fields are
                normalised in place: ``seed`` (<= 0 becomes the current time),
                ``prompt`` (space-prefixed, or replaced by the ``file`` contents),
                ``n_keep``, ``interactive`` and ``interactive_start``.

        Raises:
            NotImplementedError: ``perplexity`` or ``embedding`` was requested.
            RuntimeError: the model, context, LoRA adapter or session file could
                not be loaded, or the prompt does not fit in the context.
        """
        import codecs  # only needed for the streaming UTF-8 decoder below

        # input args
        self.params = params
        if self.params.path_session is None:
            self.params.path_session = ""
        if self.params.antiprompt is None:
            # antiprompt is iterated as a list of reverse-prompt strings below
            self.params.antiprompt = []

        if self.params.perplexity:
            raise NotImplementedError("""************
please use the 'perplexity' tool for perplexity calculations
************""")
        if self.params.embedding:
            raise NotImplementedError("""************
please use the 'embedding' tool for embedding calculations
************""")

        if self.params.n_ctx > 2048:
            print(f"""warning: model does not support \
context sizes greater than 2048 tokens ({self.params.n_ctx} \
specified) expect poor results""", file=sys.stderr)

        if self.params.seed <= 0:
            self.params.seed = int(time())

        print(f"seed = {self.params.seed}", file=sys.stderr)

        if self.params.random_prompt:
            self.params.prompt = gpt_random_prompt(self.params.seed)

        # runtime args
        self.input_consumed = 0
        self.n_past = 0
        self.n_session_consumed = 0
        self.first_antiprompt = []
        self.remaining_tokens = self.params.n_predict
        self.output_echo = self.params.input_echo
        # A token piece may end in the middle of a UTF-8 character; the
        # incremental decoder buffers the partial bytes (also across output()
        # calls) instead of failing to decode them.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
        # Mirostat's running "mu" is state that must persist between tokens.
        self.mirostat_mu = ctypes.c_float(2.0 * self.params.mirostat_tau)

        # model load
        self._load_model()

        print(file=sys.stderr)
        print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)

        # determine the required inference memory per token:
        if self.params.mem_test:
            tmp = [0, 1, 2, 3]
            llama_cpp.llama_eval(self.ctx, (llama_cpp.llama_token * len(tmp))(*tmp), len(tmp), 0)
            llama_cpp.llama_print_timings(self.ctx)
            # NOTE(review): the instance is unusable after this early return;
            # exit() is idempotent so the context manager's __exit__ is safe.
            self.exit()
            return

        # create internal context
        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)

        # Add a space in front of the first character to match OG llama tokenizer behavior
        # (a missing prompt is treated as empty instead of crashing on None).
        self.params.prompt = " " + (self.params.prompt or "")

        # Load prompt file (replaces the inline prompt verbatim)
        if self.params.file:
            with open(self.params.file, encoding="utf-8") as f:
                self.params.prompt = f.read()

        self.session_tokens: list[llama_cpp.llama_token] = []
        if len(self.params.path_session) > 0:
            self._load_session()

        # tokenize the prompt
        self.embd = []
        self.embd_inp = self._tokenize(self.params.prompt)

        if len(self.embd_inp) > self.n_ctx - 4:
            self.exit()
            raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.n_ctx - 4})")

        # debug message about similarity of saved session, if applicable
        self.n_matching_session_tokens = 0
        if len(self.session_tokens) > 0:
            # length of the common prefix between the cached session and the prompt
            for cached, wanted in zip(self.session_tokens, self.embd_inp):
                if cached != wanted:
                    break
                self.n_matching_session_tokens += 1

            if self.n_matching_session_tokens >= len(self.embd_inp):
                print("session file has exact match for prompt!", file=sys.stderr)
            elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
                print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated", file=sys.stderr)
            else:
                print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt", file=sys.stderr)

        self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)

        # number of tokens to keep when resetting context
        if self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct:
            self.params.n_keep = len(self.embd_inp)

        self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
        self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)

        # in instruct mode, we inject a prefix and a suffix to each input by the user
        self.antiecho = None
        if self.params.instruct:
            self.params.interactive_start = True
            _ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
            self.first_antiprompt.append(_ptn)
            self.antiecho = util.IterSearch(_ptn)

        # enable interactive mode if reverse prompt or interactive start is specified
        if len(self.params.antiprompt) != 0 or self.params.interactive_start:
            self.params.interactive = True

        # determine newline token (kept for callers; generation uses llama_token_nl)
        self.llama_token_newline = self._tokenize("\n", False)
        self.llama_token_eot = self._tokenize("[end of text]\n", False)

        if self.params.verbose_prompt:
            print(f"""
prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

            for tok in self.embd_inp:
                print(f"{tok} -> '{self.token_to_str(tok)}'", file=sys.stderr)

            if self.params.n_keep > 0:
                kept = b"".join(self.token_to_str(tok) for tok in self.embd_inp[:self.params.n_keep])
                print("static prompt based on n_keep: '", end="", file=sys.stderr)
                print(kept.decode("utf8", errors="ignore"), end="", file=sys.stderr)
                print("'", file=sys.stderr)
            print(file=sys.stderr)

        if self.params.interactive:
            print("interactive mode on.", file=sys.stderr)

            if len(self.params.antiprompt) > 0:
                for antiprompt in self.params.antiprompt:
                    print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)

            if len(self.params.input_prefix) > 0:
                print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)

        print(f"""sampling: repeat_last_n = {self.params.repeat_last_n}, \
repeat_penalty = {self.params.repeat_penalty}, \
presence_penalty = {self.params.presence_penalty}, \
frequency_penalty = {self.params.frequency_penalty}, \
top_k = {self.params.top_k}, \
tfs_z = {self.params.tfs_z}, \
top_p = {self.params.top_p}, \
typical_p = {self.params.typical_p}, \
temp = {self.params.temp}, \
mirostat = {self.params.mirostat}, \
mirostat_lr = {self.params.mirostat_eta}, \
mirostat_ent = {self.params.mirostat_tau}, \
generate: n_ctx = {self.n_ctx}, \
n_batch = {self.params.n_batch}, \
n_predict = {self.params.n_predict}, \
n_keep = {self.params.n_keep}

""", file=sys.stderr)

        # determine antiprompt tokens
        for i in self.params.antiprompt:
            self.first_antiprompt.append(self._tokenize(i, False))

        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesnt support slices
        if self.params.interactive:
            print("""== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to LLaMa.
 - If you want to submit another line, end your input in '\\'.

""", file=sys.stderr)
        self.set_color(util.CONSOLE_COLOR_PROMPT)

    def _load_model(self) -> None:
        """Create ``self.model`` and ``self.ctx`` and apply eos-bias / LoRA options.

        Raises:
            RuntimeError: the model/context could not be created or the LoRA
                adapter could not be applied.
        """
        # Model-level options: how the weights are loaded.
        self.lparams = llama_cpp.llama_model_default_params()
        self.lparams.use_mlock = self.params.use_mlock
        self.lparams.use_mmap = self.params.use_mmap

        self.model = llama_cpp.llama_load_model_from_file(
            self.params.model.encode("utf8"), self.lparams)
        if not self.model:
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        # Context-level options. With the model/context parameter split, the
        # sequence length, RNG seed, batch size and thread counts belong to the
        # context parameters; setting them on the model parameters has no effect.
        self.cparams = llama_cpp.llama_context_default_params()
        self.cparams.n_ctx = self.params.n_ctx
        self.cparams.seed = self.params.seed
        self.cparams.n_batch = self.params.n_batch
        self.cparams.n_threads = self.params.n_threads
        self.cparams.n_threads_batch = self.params.n_threads
        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
        if not self.ctx:
            llama_cpp.llama_free_model(self.model)
            self.model = None
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        if self.params.ignore_eos:
            self.params.logit_bias[llama_cpp.llama_token_eos(self.model)] = -float("inf")

        if len(self.params.lora_adapter) > 0:
            # NOTE(review): scale is fixed at 1.0 (full adapter strength); no
            # option for it is visible in this file.
            if llama_cpp.llama_model_apply_lora_from_file(
                self.model,
                self.params.lora_adapter.encode("utf8"),
                llama_cpp.c_float(1.0),
                self.params.lora_base.encode("utf8") if len(self.params.lora_base) > 0 else None,
                self.params.n_threads,
            ) != 0:
                self.exit()
                raise RuntimeError("error: failed to apply lora adapter")

    def _load_session(self) -> None:
        """Restore KV state and ``self.session_tokens`` from ``params.path_session`` if it exists.

        Raises:
            RuntimeError: the file exists but could not be loaded.
        """
        print(f"attempting to load saved session from '{self.params.path_session}'", file=sys.stderr)
        if not path.exists(self.params.path_session):
            print("session file does not exist, will create", file=sys.stderr)
            return

        capacity = self.n_ctx
        token_buf = (llama_cpp.llama_token * capacity)()
        n_token_count_out = llama_cpp.c_size_t()
        if not llama_cpp.llama_load_session_file(
            self.ctx,
            self.params.path_session.encode("utf8"),
            token_buf,
            capacity,
            ctypes.byref(n_token_count_out),
        ):
            self.exit()
            raise RuntimeError(f"error: failed to load session file '{self.params.path_session}'")

        n_loaded = n_token_count_out.value
        self.session_tokens = token_buf[:n_loaded]
        print(f"loaded a session with prompt size of {n_loaded} tokens", file=sys.stderr)

    # tokenize a prompt
    def _tokenize(self, prompt, bos=True):
        """Tokenize *prompt* (a str) into a list of token ids, optionally prefixed with BOS."""
        text = prompt.encode("utf8", errors="ignore")
        # Generous initial capacity; llama_tokenize reports the size it needs
        # (as a negative number) if this is still too small.
        n_max = (len(prompt) + 1) * 4
        tokens = (llama_cpp.llama_token * n_max)()
        # text_len is the UTF-8 *byte* length; the character count would cut
        # off the tail of non-ASCII prompts.
        n = llama_cpp.llama_tokenize(self.model, text, len(text), tokens, n_max, bos, False)
        if n < 0:
            n_max = -n
            tokens = (llama_cpp.llama_token * n_max)()
            n = llama_cpp.llama_tokenize(self.model, text, len(text), tokens, n_max, bos, False)
        if n < 0:
            raise RuntimeError("failed to tokenize prompt")
        return tokens[:n]

    def set_color(self, c):
        """Print the console colour escape *c* when ``use_color`` is enabled."""
        if self.params.use_color:
            print(c, end="")

    def use_antiprompt(self):
        """Return True when at least one reverse prompt (or the instruct prefix) is active."""
        return len(self.first_antiprompt) > 0

    def _reuse_session_prefix(self) -> None:
        """Skip re-evaluating leading ``embd`` tokens already present in the loaded session.

        Advances ``n_past``/``n_session_consumed`` over the matching tokens and
        drops them from ``embd``; on the first mismatch the cached session is
        truncated at that point.
        """
        if self.n_session_consumed >= len(self.session_tokens):
            return
        n_reused = 0
        for tok in self.embd:
            if tok != self.session_tokens[self.n_session_consumed]:
                self.session_tokens = self.session_tokens[:self.n_session_consumed]
                break
            self.n_past += 1
            self.n_session_consumed += 1
            n_reused += 1
            if self.n_session_consumed >= len(self.session_tokens):
                break
        # exactly the matched tokens are already in the KV cache
        if n_reused > 0:
            self.embd = self.embd[n_reused:]

    def _eval_embd(self) -> None:
        """Evaluate ``embd`` in chunks of at most ``n_batch`` tokens, advancing ``n_past``.

        Raises:
            RuntimeError: ``llama_eval`` reported a failure.
        """
        n_batch = max(1, self.params.n_batch)
        for start in range(0, len(self.embd), n_batch):
            # the last chunk may be shorter than n_batch
            chunk = self.embd[start:start + n_batch]
            if llama_cpp.llama_eval(
                self.ctx, (llama_cpp.llama_token * len(chunk))(*chunk), len(chunk), self.n_past
            ) != 0:
                raise RuntimeError("Failed to llama_eval!")
            self.n_past += len(chunk)

        if len(self.embd) > 0 and len(self.params.path_session) > 0:
            self.session_tokens.extend(self.embd)
            self.n_session_consumed = len(self.session_tokens)

    def _sample_token(self) -> int:
        """Apply logit bias and penalties to the current logits and sample one token id."""
        top_k = llama_cpp.llama_n_vocab(self.model) if self.params.top_k <= 0 else self.params.top_k
        repeat_last_n = self.n_ctx if self.params.repeat_last_n < 0 else self.params.repeat_last_n

        # optionally save the session on first sample (for faster prompt loading next time)
        if len(self.params.path_session) > 0 and self.need_to_save_session:
            self.need_to_save_session = False
            llama_cpp.llama_save_session_file(
                self.ctx,
                self.params.path_session.encode("utf8"),
                (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                len(self.session_tokens),
            )

        logits = llama_cpp.llama_get_logits(self.ctx)
        n_vocab = llama_cpp.llama_n_vocab(self.model)

        # Apply params.logit_bias map
        for key, value in self.params.logit_bias.items():
            logits[key] += value

        # candidates_data[i] describes token id i (built in token-id order)
        candidates_data = (llama_cpp.llama_token_data * n_vocab)(*[
            llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
            for token_id in range(n_vocab)
        ])
        candidates_p = ctypes.pointer(llama_cpp.llama_token_data_array(candidates_data, len(candidates_data), False))

        # Apply penalties
        nl_token = llama_cpp.llama_token_nl(self.model)
        nl_logit = logits[nl_token]
        last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
        last_tokens = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
        llama_cpp.llama_sample_repetition_penalties(
            ctx=self.ctx,
            candidates=candidates_p,
            last_tokens_data=last_tokens,
            penalty_last_n=last_n_repeat,
            penalty_repeat=llama_cpp.c_float(self.params.repeat_penalty),
            penalty_freq=llama_cpp.c_float(self.params.frequency_penalty),
            penalty_present=llama_cpp.c_float(self.params.presence_penalty),
        )

        if not self.params.penalize_nl:
            # The candidates hold copies of the logits, so the newline logit
            # must be restored there (writing to `logits` would have no effect).
            nl_entry = candidates_data[nl_token]
            if nl_entry.id == nl_token:
                nl_entry.logit = nl_logit

        if self.params.temp <= 0:
            # Greedy sampling
            return llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)

        temp = llama_cpp.c_float(self.params.temp)
        if self.params.mirostat == 1:
            mirostat_m = 100
            llama_cpp.llama_sample_temperature(self.ctx, candidates_p, temp)
            return llama_cpp.llama_sample_token_mirostat(
                self.ctx, candidates_p,
                llama_cpp.c_float(self.params.mirostat_tau),
                llama_cpp.c_float(self.params.mirostat_eta),
                llama_cpp.c_int(mirostat_m),
                ctypes.byref(self.mirostat_mu),  # updated in place across tokens
            )
        if self.params.mirostat == 2:
            llama_cpp.llama_sample_temperature(self.ctx, candidates_p, temp)
            return llama_cpp.llama_sample_token_mirostat_v2(
                self.ctx, candidates_p,
                llama_cpp.c_float(self.params.mirostat_tau),
                llama_cpp.c_float(self.params.mirostat_eta),
                ctypes.byref(self.mirostat_mu),  # updated in place across tokens
            )

        # Temperature sampling
        llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k, min_keep=llama_cpp.c_size_t(1))
        llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z), min_keep=llama_cpp.c_size_t(1))
        llama_cpp.llama_sample_typical(self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p), min_keep=llama_cpp.c_size_t(1))
        llama_cpp.llama_sample_top_p(self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p), min_keep=llama_cpp.c_size_t(1))
        llama_cpp.llama_sample_temperature(self.ctx, candidates_p, temp)
        return llama_cpp.llama_sample_token(self.ctx, candidates_p)

    # generate tokens
    def generate(self):
        """Evaluate pending tokens and sample new ones, yielding token ids to display.

        Stops when the sampling budget is spent (unless interactive or
        ``n_predict == -1``), a reverse prompt is seen in interactive mode, the
        initial instruct prompt has been processed, or end-of-text is produced.
        """
        eos_token = llama_cpp.llama_token_eos(self.model)
        while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
                if self.n_past + len(self.embd) > self.n_ctx:
                    n_left = self.n_past - self.params.n_keep
                    self.n_past = self.params.n_keep

                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        self.n_ctx - int(n_left / 2) - len(self.embd):-len(self.embd)
                    ]
                    self.embd = _insert + self.embd
                    # the context no longer matches the saved session: stop saving it
                    self.params.path_session = ""

                # try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
                self._reuse_session_prefix()
                # evaluate tokens in batches
                self._eval_embd()

            self.embd = []
            if len(self.embd_inp) <= self.input_consumed:  # && !is_interacting
                # out of user input, sample next token
                token = self._sample_token()
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(token)

                if token == eos_token and self.params.interactive and not self.params.instruct:
                    # replace end of text token with newline token when in interactive mode
                    self.embd.append(llama_cpp.llama_token_nl(self.model))
                    if self.use_antiprompt():
                        # Inject the first reverse prompt once, straight into embd
                        # so it is displayed and evaluated a single time; record it
                        # in last_n_tokens so the antiprompt check below returns
                        # control to the user.
                        for ap_token in self.first_antiprompt[0]:
                            self.embd.append(ap_token)
                            self.last_n_tokens.pop(0)
                            self.last_n_tokens.append(ap_token)
                else:
                    # add it to the context
                    self.embd.append(token)

                # echo this to console
                self.output_echo = True
                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.params.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    pending = self.embd_inp[self.input_consumed]
                    self.embd.append(pending)
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(pending)
                    self.input_consumed += 1
                    if len(self.embd) >= self.params.n_batch:
                        break

            # display tokens
            if self.output_echo:
                for tok in self.embd:
                    if self.antiecho is not None:
                        yield from self.antiecho(tok)
                    else:
                        yield tok

            # reset color to default if there is no pending user input
            if self.params.input_echo and len(self.embd_inp) == self.input_consumed:
                self.set_color(util.CONSOLE_COLOR_DEFAULT)

            if self.params.interactive and len(self.embd_inp) <= self.input_consumed:
                # if antiprompt is present, stop
                if self.use_antiprompt():
                    if any(ap == self.last_n_tokens[-len(ap):] for ap in self.first_antiprompt):
                        break

                # if we are using instruction mode, and we have processed the initial prompt
                if self.params.interactive_start:
                    break

            # end of text token: report it, and hand control back in instruct mode
            if len(self.embd) > 0 and self.embd[-1] == eos_token:
                if not self.params.instruct:
                    yield from self.llama_token_eot
                break

            # respect n_predict even if antiprompt is present
            if self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1:
                # If we arent in instruction mode, fix the current generation by appending the antiprompt.
                # Makes it so if chat ends prematurely you dont append the AI's text etc.
                if not self.params.instruct and self.use_antiprompt():
                    self.embd_inp += self.first_antiprompt[0]
                self.remaining_tokens = self.params.n_predict
                break

        self.params.interactive_start = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.exit()

    def exit(self):
        """Free the native context and model (safe to call more than once) and reset the colour."""
        if getattr(self, "ctx", None):
            llama_cpp.llama_free(self.ctx)
            self.ctx = None
        if getattr(self, "model", None):
            llama_cpp.llama_free_model(self.model)
            self.model = None
        self.set_color(util.CONSOLE_COLOR_DEFAULT)

    def token_to_str(self, token_id: int) -> bytes:
        """Return the raw UTF-8 bytes of one token's text piece (may be a partial character)."""
        size = 32
        buffer = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece(
            self.model, llama_cpp.llama_token(token_id), buffer, size)
        if n < 0:
            # buffer too small: the negated return value is the required size
            size = -n
            buffer = (ctypes.c_char * size)()
            n = llama_cpp.llama_token_to_piece(
                self.model, llama_cpp.llama_token(token_id), buffer, size)
        if n < 0 or n > size:
            raise RuntimeError(f"failed to convert token {token_id} to text")
        return bytes(buffer[:n])

    # return past text
    def past(self):
        """Yield the decoded text of the tokens currently in the context window (up to n_past)."""
        if self.n_past <= 0:
            # [-0:] would otherwise select the whole (zero-filled) history
            return
        for tok in self.last_n_tokens[-self.n_past:]:
            yield self.token_to_str(tok).decode("utf8", errors="ignore")

    # write input
    def input(self, prompt: str):
        """Queue *prompt* for evaluation, wrapped in the instruct prefix/suffix in instruct mode."""
        if self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix:
            self.embd_inp += self.inp_prefix

        # no BOS: this text continues the running conversation
        self.embd_inp += self._tokenize(prompt, False)

        if self.params.instruct:
            self.embd_inp += self.inp_suffix

    # write output
    def output(self):
        """Run one generation turn, yielding decoded text as complete UTF-8 characters arrive."""
        self.remaining_tokens = self.params.n_predict
        for tok in self.generate():
            text = self._utf8_decoder.decode(self.token_to_str(tok))
            if text:
                yield text

    # read user input
    def read_input(self):
        """Read one console message; a trailing backslash continues it on the next line."""
        pieces = []
        while (line := input()).endswith("\\"):
            pieces.append(line[:-1] + "\n")
        return "".join(pieces) + line + "\n"

    # interactive mode
    def interact(self):
        """Run a console chat loop until ``params.interactive`` is cleared."""
        for text in self.output():
            print(text, end="", flush=True)
        self.params.input_echo = False

        # Using string instead of tokens to check for antiprompt,
        # It is more reliable than tokens for interactive mode.
        generated_str = ""
        while self.params.interactive:
            self.set_color(util.CONSOLE_COLOR_USER_INPUT)
            if self.params.instruct:
                print('\n> ', end="")
                self.input(self.read_input())
            else:
                print(self.params.input_prefix, end="")
                self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}")
                print(self.params.input_suffix, end="")
            self.set_color(util.CONSOLE_COLOR_DEFAULT)

            try:
                for text in self.output():
                    print(text, end="", flush=True)
                    generated_str += text
                    # empty reverse prompts would match after every character
                    if any(ap and generated_str.endswith(ap) for ap in self.params.antiprompt):
                        raise KeyboardInterrupt
            except KeyboardInterrupt:
                self.set_color(util.CONSOLE_COLOR_DEFAULT)
                if not self.params.instruct:
                    print(self.params.fix_prefix, end="")
                    self.input(self.params.fix_prefix)
2023-04-03 22:54:46 +02:00
if __name__ == "__main__":
    from datetime import datetime

    USER_NAME = "User"
    AI_NAME = "ChatLLaMa"

    def _default_chat_prompt(user, ai, now):
        """Return the built-in chat primer, stamped with the time/year of *now*."""
        return f"""Text transcript of a never ending dialog, where {user} interacts with an AI assistant named {ai}.
{ai} is helpful, kind, honest, friendly, good at writing and never fails to answer {user}’s requests immediately and with details and precision.
Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {user} and {ai} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
{user}: Hello, {ai}!
{ai}: Hello {user}! How may I help you today?
{user}: What time is it?
{ai}: It is {now.strftime("%H:%M")}.
{user}: What year is it?
{ai}: We are in {now.strftime("%Y")}.
{user}: What is a cat?
{ai}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{user}: Name a color.
{ai}: Blue
{user}:"""

    # capture the clock before argument parsing, as the prompt's timestamp
    time_now = datetime.now()
    prompt = _default_chat_prompt(USER_NAME, AI_NAME, time_now)

    params = gpt_params_parse()
    # fall back to the built-in primer only when no prompt source was given
    no_prompt_source = params.prompt is None and params.file is None
    if no_prompt_source:
        params.prompt = prompt

    with LLaMAInteract(params) as session:
        session.interact()