2023-04-06 15:30:57 +02:00
import os
import argparse
2023-05-04 18:33:08 +02:00
import re
2023-04-06 15:30:57 +02:00
from dataclasses import dataclass , field
2023-05-04 18:33:08 +02:00
from typing import List
2023-04-06 15:30:57 +02:00
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
@dataclass
class GptParams:
    """Settings for the low-level llama.cpp example scripts.

    Modelled on the parameters of llama.cpp's ``examples/common.cpp``. An
    instance is normally produced by ``gpt_params_parse`` from command-line
    flags, but it can also be built directly with keyword overrides.
    """

    # --- Runtime / context -------------------------------------------------
    seed: int = -1  # RNG seed (use random seed for <= 0)
    # Evaluated once, when the class body runs, not per instance.
    n_threads: int = min(4, os.cpu_count() or 1)
    n_predict: int = 128  # number of tokens to predict (-1 = infinity)
    n_parts: int = -1
    n_ctx: int = 512  # size of the prompt context
    n_batch: int = 8  # batch size for prompt processing
    n_keep: int = 0  # number of tokens to keep from the initial prompt

    # --- Sampling ----------------------------------------------------------
    ignore_eos: bool = False
    # token id -> bias; default_factory gives each instance its own dict.
    logit_bias: dict[int, float] = field(default_factory=dict)
    top_k: int = 40
    top_p: float = 0.95
    tfs_z: float = 1.00  # tail free sampling z (1.0 = disabled)
    typical_p: float = 1.00
    temp: float = 0.80
    repeat_penalty: float = 1.10
    repeat_last_n: int = 64
    frequency_penalty: float = 0.0  # 0.0 = disabled
    presence_penalty: float = 0.0  # 0.0 = disabled
    mirostat: int = 0
    mirostat_tau: float = 5.0
    mirostat_eta: float = 0.1

    # --- Model, prompts and user input ------------------------------------
    model: str = "./models/llama-7B/ggml-model.bin"
    prompt: str = ""
    path_session: str = ""
    input_prefix: str = ""
    input_suffix: str = ""
    # Reverse prompts; default_factory gives each instance its own list.
    antiprompt: List[str] = field(default_factory=list)

    lora_adapter: str = ""
    lora_base: str = ""

    # --- Flags -------------------------------------------------------------
    memory_f16: bool = True
    random_prompt: bool = False
    use_color: bool = False
    interactive: bool = False
    embedding: bool = False
    interactive_start: bool = False
    instruct: bool = False
    penalize_nl: bool = True
    perplexity: bool = False
    use_mmap: bool = True
    use_mlock: bool = False
    mem_test: bool = False
    verbose_prompt: bool = False

    file: str = None  # path of a file holding the initial prompt; None if unset

    # If chat ended prematurely, append this to the conversation to fix it.
    # Set to "\nUser:" etc.
    # This is an alternative to input_prefix, which always adds it and so can
    # duplicate "User:".
    fix_prefix: str = ""
    # Echo the user's input back; cleared by --input-noecho. Must be a plain
    # bool: a trailing comma here would turn the default into a 1-tuple.
    input_echo: bool = True

    # Default instructions for Alpaca;
    # switch to "Human" and "Assistant" for Vicuna.
    # TODO: TBD how they are gonna handle this upstream
    instruct_inp_prefix: str = "\n\n### Instruction:\n\n"
    instruct_inp_suffix: str = "\n\n### Response:\n\n"
def gpt_params_parse(argv=None):
    """Parse llama.cpp-style command-line flags into a ``GptParams``.

    Args:
        argv: Arguments to parse, without the program name. ``None`` means
            ``sys.argv[1:]`` (standard ``ArgumentParser.parse_args`` behavior).

    Returns:
        A populated ``GptParams`` instance.

    Raises:
        SystemExit: raised by argparse for ``--help`` and for invalid
            arguments, including a malformed ``--logit-bias`` entry.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # --logit-bias is not a GptParams field: take the raw strings out of the
    # namespace before it is splatted into the dataclass constructor.
    logit_bias_specs = args.logit_bias_str
    delattr(args, "logit_bias_str")
    logit_bias = {}
    if logit_bias_specs is not None:
        logit_bias = _parse_logit_bias(parser, logit_bias_specs)

    # action="append" yields None when -r is never given; use the same empty
    # list the dataclass default would produce.
    if args.antiprompt is None:
        args.antiprompt = []

    params = GptParams(**vars(args))

    # Per the --lora help text, applying a LoRA adapter implies --no-mmap.
    if params.lora_adapter:
        params.use_mmap = False

    params.logit_bias.update(logit_bias)
    return params


def _build_arg_parser():
    """Create the argparse parser; every ``dest`` names a GptParams field."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)", dest="seed")
    parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation", dest="n_threads")
    parser.add_argument("-n", "--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)", dest="n_predict")
    parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
    parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context", dest="n_ctx")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing", dest="n_batch")
    parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt", dest="n_keep")
    parser.add_argument(
        "-l",
        "--logit-bias",
        type=str,
        action="append",
        help="--logit-bias TOKEN_ID(+/-)BIAS",
        dest="logit_bias_str",
    )
    parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
    parser.add_argument("--top_k", type=int, default=40, help="top-k sampling", dest="top_k")
    parser.add_argument("--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p")
    parser.add_argument("--tfs", type=float, default=1.0, help="tail free sampling, parameter z (1.0 = disabled)", dest="tfs_z")
    parser.add_argument("--typical", type=float, default=1.0, help="locally typical sampling, parameter p (1.0 = disabled)", dest="typical_p")
    parser.add_argument("--temp", type=float, default=0.80, help="temperature", dest="temp")
    parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens", dest="repeat_penalty")
    parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize", dest="repeat_last_n")
    parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)", dest="frequency_penalty")
    parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)", dest="presence_penalty")
    # int with default 0 to match GptParams.mirostat; a default of 1.0 would
    # switch Mirostat on for every run.
    parser.add_argument("--mirostat", type=int, default=0, help="use Mirostat sampling (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", dest="mirostat")
    parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau represents the average surprise value", dest="mirostat_tau")
    parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta", dest="mirostat_eta")
    parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path", dest="model")
    parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt", dest="prompt")
    parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load", dest="file")
    # "" (not None) to match GptParams.path_session, so callers can treat it as a str.
    parser.add_argument("--session", type=str, default="", help="file to cache model state in (may be large!)", dest="path_session")
    parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
    parser.add_argument("--in-suffix", type=str, default="", help="append to input", dest="input_suffix")
    parser.add_argument(
        "-r",
        "--reverse-prompt",
        type=str,
        action="append",
        help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
        dest="antiprompt",
    )
    parser.add_argument("--lora", type=str, default="", help="apply LoRA adapter (implies --no-mmap)", dest="lora_adapter")
    parser.add_argument("--lora-base", type=str, default="", help="optional model to use as a base for the layers modified by the LoRA adapter", dest="lora_base")
    parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value", dest="memory_f16")
    parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt")
    parser.add_argument("--color", action="store_true", help="colorise output to distinguish prompt and user input from generations", dest="use_color")
    parser.add_argument("-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive")
    parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
    parser.add_argument("--interactive-first", action="store_true", help="run in interactive mode and wait for input right away", dest="interactive_start")
    parser.add_argument("-ins", "--instruct", action="store_true", help="run in instruction mode (use with Alpaca or Vicuna models)", dest="instruct")
    parser.add_argument("--no-penalize-nl", action="store_false", help="do not penalize newline token", dest="penalize_nl")
    parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
    parser.add_argument("--no-mmap", action="store_false", help="do not memory-map model (slower load but may reduce pageouts if not using mlock)", dest="use_mmap")
    parser.add_argument("--mlock", action="store_true", help="force system to keep model in RAM rather than swapping or compressing", dest="use_mlock")
    parser.add_argument("--mtest", action="store_true", help="compute maximum memory usage", dest="mem_test")
    parser.add_argument("--verbose-prompt", action="store_true", help="print prompt before generation", dest="verbose_prompt")

    # Custom args
    parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix")
    parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo")
    # Alias of -i/--interactive (same dest); unrelated to --interactive-first.
    parser.add_argument("--interactive-start", action="store_true", help="run in interactive mode", dest="interactive")
    return parser


def _parse_logit_bias(parser, specs):
    """Turn ``TOKEN_ID(+/-)BIAS`` strings into a ``{token_id: bias}`` dict.

    The bias keeps its sign and may be anything ``float()`` accepts after it,
    e.g. ``15043+1``, ``15043-0.5`` or ``2-inf``. Later entries for the same
    token override earlier ones. A malformed entry is reported through
    ``parser.error`` (which exits) instead of being silently dropped.
    """
    biases = {}
    for spec in specs:
        match = re.fullmatch(r"(\d+)([-+].+)", spec.strip())
        try:
            # float() consumes the sign itself and also understands "inf".
            bias = float(match.group(2)) if match else None
        except ValueError:
            bias = None
        if bias is None:
            parser.error(f"invalid --logit-bias value {spec!r}: expected TOKEN_ID(+/-)BIAS, e.g. 15043+1")
        biases[int(match.group(1))] = bias
    return biases
def gpt_random_prompt(rng):
    """Return one of ten fixed prompt openers, chosen by ``rng % 10``.

    ``rng`` is an integer (e.g. a random number). Because Python's ``%`` with
    a positive modulus is never negative, any integer maps to a valid entry.
    """
    openers = (
        "So",
        "Once upon a time",
        "When",
        "The",
        "After",
        "If",
        "import",
        "He",
        "She",
        "They",
    )
    slot = rng % len(openers)
    return openers[slot]
if __name__ == "__main__":
    # Debug entry point: parse the real command line and show the resulting
    # GptParams via its dataclass-generated repr.
    cli_params = gpt_params_parse()
    print(cli_params)