# 2023-12-22 11:51:25 +01:00
from __future__ import annotations
import multiprocessing
from typing import Optional , List , Literal
from pydantic import Field
from pydantic_settings import BaseSettings
import llama_cpp
# Disable the warning for the `model` and `model_alias` settings by clearing
# the protected namespaces. The key must be spelled exactly
# "protected_namespaces" (no surrounding whitespace) or the assignment only
# adds an unrelated entry to the config mapping.
# NOTE: this mutates the config on BaseSettings itself, not on a subclass.
BaseSettings.model_config["protected_namespaces"] = ()
class ModelSettings(BaseSettings):
    """Model settings used to load a Llama model.

    Fields are grouped (see the section comments) into model parameters,
    context parameters, sampling, LoRA, backend, chat format, cache,
    tokenizer options and miscellaneous flags.

    Only ``model`` has no default and is therefore required. Several defaults
    are computed once, when this class body is executed at import time:
    ``use_mmap`` / ``use_mlock`` query ``llama_cpp`` for platform support and
    ``n_threads`` / ``n_threads_batch`` use half of
    ``multiprocessing.cpu_count()`` (never less than 1).
    """

    model: str = Field(
        description="The path to the model to use for generating completions."
    )
    model_alias: Optional[str] = Field(
        default=None,
        description="The alias of the model to use for generating completions.",
    )

    # Model Params
    n_gpu_layers: int = Field(
        default=0,
        ge=-1,
        description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
    )
    split_mode: int = Field(
        default=llama_cpp.LLAMA_SPLIT_LAYER,
        description="The split mode to use.",
    )
    main_gpu: int = Field(
        default=0,
        ge=0,
        description="Main GPU to use.",
    )
    tensor_split: Optional[List[float]] = Field(
        default=None,
        description="Split layers across multiple GPUs in proportion.",
    )
    vocab_only: bool = Field(
        default=False, description="Whether to only return the vocabulary."
    )
    # Evaluated once at class-definition time, not per instance.
    use_mmap: bool = Field(
        default=llama_cpp.llama_mmap_supported(),
        description="Use mmap.",
    )
    # Evaluated once at class-definition time, not per instance.
    use_mlock: bool = Field(
        default=llama_cpp.llama_mlock_supported(),
        description="Use mlock.",
    )
    kv_overrides: Optional[List[str]] = Field(
        default=None,
        description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
    )

    # Context Params
    seed: int = Field(
        default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
    )
    n_ctx: int = Field(default=2048, ge=0, description="The context size.")
    n_batch: int = Field(
        default=512, ge=1, description="The batch size to use per eval."
    )
    # Half of the reported CPU count, floored, but never fewer than one thread.
    n_threads: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=1,
        description="The number of threads to use.",
    )
    # NOTE(review): 0 is accepted here (ge=0) while n_threads requires >= 1;
    # confirm what 0 means to llama_cpp before tightening this bound.
    n_threads_batch: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=0,
        description="The number of threads to use when batch processing.",
    )
    rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
    rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
    rope_freq_scale: float = Field(
        default=0.0, description="RoPE frequency scaling factor"
    )
    yarn_ext_factor: float = Field(default=-1.0)
    yarn_attn_factor: float = Field(default=1.0)
    yarn_beta_fast: float = Field(default=32.0)
    yarn_beta_slow: float = Field(default=1.0)
    yarn_orig_ctx: int = Field(default=0)
    mul_mat_q: bool = Field(
        default=True, description="if true, use experimental mul_mat_q kernels"
    )
    logits_all: bool = Field(default=True, description="Whether to return logits.")
    embedding: bool = Field(default=True, description="Whether to use embeddings.")
    offload_kqv: bool = Field(
        default=True, description="Whether to offload kqv to the GPU."
    )

    # Sampling Params
    last_n_tokens_size: int = Field(
        default=64,
        ge=0,
        description="Last n tokens to keep for repeat penalty calculation.",
    )

    # LoRA Params
    lora_base: Optional[str] = Field(
        default=None,
        description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
    )
    lora_path: Optional[str] = Field(
        default=None,
        description="Path to a LoRA file to apply to the model.",
    )

    # Backend Params
    numa: bool = Field(
        default=False,
        description="Enable NUMA support.",
    )

    # Chat Format Params
    chat_format: str = Field(
        default="llama-2",
        description="Chat format to use.",
    )
    clip_model_path: Optional[str] = Field(
        default=None,
        description="Path to a CLIP model to use for multi-modal chat completion.",
    )

    # Cache Params
    cache: bool = Field(
        default=False,
        description="Use a cache to reduce processing times for evaluated prompts.",
    )
    # Only the exact strings "ram" and "disk" are accepted.
    cache_type: Literal["ram", "disk"] = Field(
        default="ram",
        description="The type of cache to use. Only used if cache is True.",
    )
    # 2 << 30 == 2147483648 bytes (2 GiB).
    cache_size: int = Field(
        default=2 << 30,
        description="The size of the cache in bytes. Only used if cache is True.",
    )

    # Tokenizer Options
    hf_tokenizer_config_path: Optional[str] = Field(
        default=None,
        description="The path to a HuggingFace tokenizer_config.json file.",
    )
    hf_pretrained_model_name_or_path: Optional[str] = Field(
        default=None,
        description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
    )

    # Misc
    verbose: bool = Field(
        default=True, description="Whether to print debug information."
    )
class ServerSettings(BaseSettings):
    """Server settings used to configure the FastAPI and Uvicorn server.

    Every field has a default, so the settings can be constructed with no
    arguments. The first group of fields is labelled for Uvicorn (listen
    address, port, optional SSL key/certificate files); the second group is
    labelled for FastAPI (optional API key, request interruption).
    """

    # Uvicorn Settings
    host: str = Field(default="localhost", description="Listen address")
    port: int = Field(default=8000, description="Listen port")
    ssl_keyfile: Optional[str] = Field(
        default=None, description="SSL key file for HTTPS"
    )
    ssl_certfile: Optional[str] = Field(
        default=None, description="SSL certificate file for HTTPS"
    )

    # FastAPI Settings
    # None (the default) leaves the API key unset.
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication. If set all requests need to be authenticated.",
    )
    interrupt_requests: bool = Field(
        default=True,
        description="Whether to interrupt requests when a new request is received.",
    )
class Settings(ServerSettings, ModelSettings):
    """Combined settings: every ServerSettings field plus every ModelSettings field.

    Declares no fields of its own; it only merges the two parent classes.
    """
class ConfigFileSettings(ServerSettings):
    """Configuration file format settings.

    Extends ServerSettings with ``models``, a list of ModelSettings entries,
    so one configuration can describe the server plus any number of models.
    The list is empty by default.
    """

    # NOTE(review): a literal [] default is kept to leave the generated schema
    # unchanged; this relies on pydantic copying mutable defaults per
    # instance — confirm against the pinned pydantic version.
    models: List[ModelSettings] = Field(
        default=[], description="Model configs"
    )