2023-12-22 11:51:25 +01:00
from __future__ import annotations
import multiprocessing
2024-04-25 21:21:48 -04:00
from typing import Optional , List , Literal , Union , Dict , cast
from typing_extensions import Self
from pydantic import Field , model_validator
2023-12-22 11:51:25 +01:00
from pydantic_settings import BaseSettings
import llama_cpp
# Silence the pydantic warning that the `model` and `model_alias` settings
# fields would otherwise trigger, by clearing the protected namespaces.
# NOTE(review): this assigns on BaseSettings itself, so it applies to every
# BaseSettings subclass in the process, not only the classes in this module.
BaseSettings.model_config["protected_namespaces"] = ()
class ModelSettings(BaseSettings):
    """Model settings used to load a Llama model."""

    # Fields are grouped by the part of the stack they configure (model
    # loading, context, sampling, LoRA, backend, chat format, cache,
    # tokenizer, hub loading, speculative decoding, KV cache, misc).
    # `n_threads` / `n_threads_batch` accept -1 as "use every CPU"; that
    # sentinel is resolved by `set_dynamic_defaults` before the `ge=`
    # constraints on those fields are checked.

    model: str = Field(
        description="The path to the model to use for generating completions."
    )
    model_alias: Optional[str] = Field(
        default=None,
        description="The alias of the model to use for generating completions.",
    )

    # Model Params
    n_gpu_layers: int = Field(
        default=0,
        ge=-1,
        description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
    )
    split_mode: int = Field(
        default=llama_cpp.LLAMA_SPLIT_MODE_LAYER,
        description="The split mode to use.",
    )
    main_gpu: int = Field(
        default=0,
        ge=0,
        description="Main GPU to use.",
    )
    tensor_split: Optional[List[float]] = Field(
        default=None,
        description="Split layers across multiple GPUs in proportion.",
    )
    vocab_only: bool = Field(
        default=False, description="Whether to only return the vocabulary."
    )
    # The mmap/mlock defaults are queried from llama_cpp once, at import time.
    use_mmap: bool = Field(
        default=llama_cpp.llama_supports_mmap(),
        description="Use mmap.",
    )
    use_mlock: bool = Field(
        default=llama_cpp.llama_supports_mlock(),
        description="Use mlock.",
    )
    kv_overrides: Optional[List[str]] = Field(
        default=None,
        description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
    )
    rpc_servers: Optional[str] = Field(
        default=None,
        description="comma seperated list of rpc servers for offloading",
    )

    # Context Params
    seed: int = Field(
        default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
    )
    n_ctx: int = Field(default=2048, ge=0, description="The context size.")
    n_batch: int = Field(
        default=512, ge=1, description="The batch size to use per eval."
    )
    # Default: half of the reported CPUs, but never fewer than one thread.
    n_threads: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=1,
        description="The number of threads to use. Use -1 for max cpu threads",
    )
    # Default: all reported CPUs, but never fewer than one thread.
    n_threads_batch: int = Field(
        default=max(multiprocessing.cpu_count(), 1),
        ge=0,
        description="The number of threads to use when batch processing. Use -1 for max cpu threads",
    )
    rope_scaling_type: int = Field(
        default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
    rope_freq_scale: float = Field(
        default=0.0, description="RoPE frequency scaling factor"
    )
    yarn_ext_factor: float = Field(default=-1.0)
    yarn_attn_factor: float = Field(default=1.0)
    yarn_beta_fast: float = Field(default=32.0)
    yarn_beta_slow: float = Field(default=1.0)
    yarn_orig_ctx: int = Field(default=0)
    mul_mat_q: bool = Field(
        default=True, description="if true, use experimental mul_mat_q kernels"
    )
    logits_all: bool = Field(default=True, description="Whether to return logits.")
    embedding: bool = Field(default=True, description="Whether to use embeddings.")
    offload_kqv: bool = Field(
        default=True, description="Whether to offload kqv to the GPU."
    )
    flash_attn: bool = Field(
        default=False, description="Whether to use flash attention."
    )

    # Sampling Params
    last_n_tokens_size: int = Field(
        default=64,
        ge=0,
        description="Last n tokens to keep for repeat penalty calculation.",
    )

    # LoRA Params
    lora_base: Optional[str] = Field(
        default=None,
        description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
    )
    lora_path: Optional[str] = Field(
        default=None,
        description="Path to a LoRA file to apply to the model.",
    )

    # Backend Params
    numa: Union[bool, int] = Field(
        default=False,
        description="Enable NUMA support.",
    )

    # Chat Format Params
    chat_format: Optional[str] = Field(
        default=None,
        description="Chat format to use.",
    )
    clip_model_path: Optional[str] = Field(
        default=None,
        description="Path to a CLIP model to use for multi-modal chat completion.",
    )

    # Cache Params
    cache: bool = Field(
        default=False,
        description="Use a cache to reduce processing times for evaluated prompts.",
    )
    cache_type: Literal["ram", "disk"] = Field(
        default="ram",
        description="The type of cache to use. Only used if cache is True.",
    )
    cache_size: int = Field(
        default=2 << 30,  # 2 GiB
        description="The size of the cache in bytes. Only used if cache is True.",
    )

    # Tokenizer Options
    hf_tokenizer_config_path: Optional[str] = Field(
        default=None,
        description="The path to a HuggingFace tokenizer_config.json file.",
    )
    hf_pretrained_model_name_or_path: Optional[str] = Field(
        default=None,
        description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
    )

    # Loading from HuggingFace Model Hub
    hf_model_repo_id: Optional[str] = Field(
        default=None,
        description="The model repo id to use for the HuggingFace tokenizer model.",
    )

    # Speculative Decoding
    draft_model: Optional[str] = Field(
        default=None,
        description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
    )
    draft_model_num_pred_tokens: int = Field(
        default=10,
        description="Number of tokens to predict using the draft model.",
    )

    # KV Cache Quantization
    type_k: Optional[int] = Field(
        default=None,
        description="Type of the key cache quantization.",
    )
    type_v: Optional[int] = Field(
        default=None,
        description="Type of the value cache quantization.",
    )

    # Misc
    verbose: bool = Field(
        default=True, description="Whether to print debug information."
    )

    # mode="before" makes this run on the raw input, ahead of field validation,
    # so the -1 sentinel is replaced before the `ge=` bounds are enforced.
    @model_validator(mode="before")
    @classmethod
    def set_dynamic_defaults(cls, data: object) -> object:
        """Replace the ``-1`` thread-count sentinel with the CPU count.

        Args:
            data: The raw, not-yet-validated input for the model. Only dict
                input is inspected; anything else is handed back untouched so
                that regular validation can deal with it.

        Returns:
            A shallow copy of ``data`` in which ``n_threads`` and
            ``n_threads_batch`` are set to ``multiprocessing.cpu_count()``
            wherever they were ``-1`` (or the string ``"-1"``); or ``data``
            itself when it is not a dict.
        """
        if not isinstance(data, dict):
            return data

        cpu_count = multiprocessing.cpu_count()
        # Work on a copy so the caller's mapping is not modified in place.
        resolved = dict(data)
        for name in ("n_threads", "n_threads_batch"):
            raw = resolved.get(name)
            if isinstance(raw, str):
                # Values that have not been coerced yet (e.g. text taken from
                # the environment) still need to match the sentinel.
                raw = raw.strip()
            if raw in (-1, "-1"):
                resolved[name] = cpu_count
        return resolved
2024-04-23 02:35:38 -04:00
2023-12-22 11:51:25 +01:00
class ServerSettings(BaseSettings):
    """Server settings used to configure the FastAPI and Uvicorn server."""

    # --- Uvicorn: where to listen and optional TLS material ---
    host: str = Field(default="localhost", description="Listen address")
    port: int = Field(default=8000, description="Listen port")
    ssl_keyfile: Optional[str] = Field(
        default=None,
        description="SSL key file for HTTPS",
    )
    ssl_certfile: Optional[str] = Field(
        default=None,
        description="SSL certificate file for HTTPS",
    )

    # --- FastAPI: authentication and request handling ---
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication. If set all requests need to be authenticated.",
    )
    interrupt_requests: bool = Field(
        default=True,
        description="Whether to interrupt requests when a new request is received.",
    )
    disable_ping_events: bool = Field(
        default=False,
        description="Disable EventSource pings (may be needed for some clients).",
    )
    root_path: str = Field(
        default="",
        description="The root path for the server. Useful when running behind a reverse proxy.",
    )
2023-12-22 11:51:25 +01:00
class Settings(ServerSettings, ModelSettings):
    # Flat combination of the server and model settings; it declares no
    # fields of its own and inherits everything from its two bases.
    pass
class ConfigFileSettings(ServerSettings):
    """Configuration file format settings."""

    # Server-level options are inherited from ServerSettings; each entry in
    # `models` is validated as a separate ModelSettings.
    models: List[ModelSettings] = Field(
        default=[],
        description="Model configs",
    )