diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 66caaa9..dcc7be7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -105,6 +105,9 @@ class Llama: draft_model: Optional[LlamaDraftModel] = None, # Tokenizer Override tokenizer: Optional[BaseLlamaTokenizer] = None, + # KV cache quantization + type_k: Optional[int] = None, + type_v: Optional[int] = None, # Misc verbose: bool = True, # Extra Params @@ -172,6 +175,8 @@ class Llama: draft_model: Optional draft model to use for speculative decoding. tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp. verbose: Print verbose output to stderr. + type_k: KV cache data type for K (default: f16) + type_v: KV cache data type for V (default: f16) Raises: ValueError: If the model path does not exist. @@ -298,7 +303,11 @@ class Llama: ) # Must be set to True for speculative decoding self.context_params.embeddings = embedding # TODO: Rename to embeddings self.context_params.offload_kqv = offload_kqv - + # KV cache quantization + if type_k is not None: + self.context_params.type_k = type_k + if type_v is not None: + self.context_params.type_v = type_v # Sampling Params self.last_n_tokens_size = last_n_tokens_size @@ -1724,6 +1733,7 @@ class Llama: n_threads=self.context_params.n_threads, n_threads_batch=self.context_params.n_threads_batch, rope_scaling_type=self.context_params.rope_scaling_type, + pooling_type=self.context_params.pooling_type, rope_freq_base=self.context_params.rope_freq_base, rope_freq_scale=self.context_params.rope_freq_scale, yarn_ext_factor=self.context_params.yarn_ext_factor, @@ -1733,6 +1743,7 @@ class Llama: yarn_orig_ctx=self.context_params.yarn_orig_ctx, logits_all=self.context_params.logits_all, embedding=self.context_params.embeddings, + offload_kqv=self.context_params.offload_kqv, # Sampling Params last_n_tokens_size=self.last_n_tokens_size, # LoRA Params @@ -1744,51 +1755,17 @@ class Llama: # Chat Format Params chat_format=self.chat_format, chat_handler=self.chat_handler, + # Speculative Decidng + draft_model=self.draft_model, + # KV cache quantization + type_k=self.context_params.type_k, + type_v=self.context_params.type_v, # Misc verbose=self.verbose, ) def __setstate__(self, state): - self.__init__( - model_path=state["model_path"], - # Model Params - n_gpu_layers=state["n_gpu_layers"], - split_mode=state["split_mode"], - main_gpu=state["main_gpu"], - tensor_split=state["tensor_split"], - vocab_only=state["vocab_only"], - use_mmap=state["use_mmap"], - use_mlock=state["use_mlock"], - kv_overrides=state["kv_overrides"], - # Context Params - seed=state["seed"], - n_ctx=state["n_ctx"], - n_batch=state["n_batch"], - n_threads=state["n_threads"], - n_threads_batch=state["n_threads_batch"], - rope_freq_base=state["rope_freq_base"], - rope_freq_scale=state["rope_freq_scale"], - rope_scaling_type=state["rope_scaling_type"], - yarn_ext_factor=state["yarn_ext_factor"], - yarn_attn_factor=state["yarn_attn_factor"], - yarn_beta_fast=state["yarn_beta_fast"], - yarn_beta_slow=state["yarn_beta_slow"], - yarn_orig_ctx=state["yarn_orig_ctx"], - logits_all=state["logits_all"], - embedding=state["embedding"], - # Sampling Params - last_n_tokens_size=state["last_n_tokens_size"], - # LoRA Params - lora_base=state["lora_base"], - lora_path=state["lora_path"], - # Backend Params - numa=state["numa"], - # Chat Format Params - chat_format=state["chat_format"], - chat_handler=state["chat_handler"], - # Misc - verbose=state["verbose"], - ) + self.__init__(**state) def save_state(self) -> LlamaState: assert self._ctx.ctx is not None diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 1db47be..accc02c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -141,6 +141,70 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa byref = ctypes.byref # type: ignore +# from ggml.h +# // NOTE: always add types at the end of the enum to keep backward compatibility +# enum ggml_type { +# GGML_TYPE_F32 = 0, +# GGML_TYPE_F16 = 1, +# GGML_TYPE_Q4_0 = 2, +# GGML_TYPE_Q4_1 = 3, +# // GGML_TYPE_Q4_2 = 4, support has been removed +# // GGML_TYPE_Q4_3 = 5, support has been removed +# GGML_TYPE_Q5_0 = 6, +# GGML_TYPE_Q5_1 = 7, +# GGML_TYPE_Q8_0 = 8, +# GGML_TYPE_Q8_1 = 9, +# GGML_TYPE_Q2_K = 10, +# GGML_TYPE_Q3_K = 11, +# GGML_TYPE_Q4_K = 12, +# GGML_TYPE_Q5_K = 13, +# GGML_TYPE_Q6_K = 14, +# GGML_TYPE_Q8_K = 15, +# GGML_TYPE_IQ2_XXS = 16, +# GGML_TYPE_IQ2_XS = 17, +# GGML_TYPE_IQ3_XXS = 18, +# GGML_TYPE_IQ1_S = 19, +# GGML_TYPE_IQ4_NL = 20, +# GGML_TYPE_IQ3_S = 21, +# GGML_TYPE_IQ2_S = 22, +# GGML_TYPE_IQ4_XS = 23, +# GGML_TYPE_I8 = 24, +# GGML_TYPE_I16 = 25, +# GGML_TYPE_I32 = 26, +# GGML_TYPE_I64 = 27, +# GGML_TYPE_F64 = 28, +# GGML_TYPE_IQ1_M = 29, +# GGML_TYPE_COUNT, +# }; +GGML_TYPE_F32 = 0 +GGML_TYPE_F16 = 1 +GGML_TYPE_Q4_0 = 2 +GGML_TYPE_Q4_1 = 3 +GGML_TYPE_Q5_0 = 6 +GGML_TYPE_Q5_1 = 7 +GGML_TYPE_Q8_0 = 8 +GGML_TYPE_Q8_1 = 9 +GGML_TYPE_Q2_K = 10 +GGML_TYPE_Q3_K = 11 +GGML_TYPE_Q4_K = 12 +GGML_TYPE_Q5_K = 13 +GGML_TYPE_Q6_K = 14 +GGML_TYPE_Q8_K = 15 +GGML_TYPE_IQ2_XXS = 16 +GGML_TYPE_IQ2_XS = 17 +GGML_TYPE_IQ3_XXS = 18 +GGML_TYPE_IQ1_S = 19 +GGML_TYPE_IQ4_NL = 20 +GGML_TYPE_IQ3_S = 21 +GGML_TYPE_IQ2_S = 22 +GGML_TYPE_IQ4_XS = 23 +GGML_TYPE_I8 = 24 +GGML_TYPE_I16 = 25 +GGML_TYPE_I32 = 26 +GGML_TYPE_I64 = 27 +GGML_TYPE_F64 = 28 +GGML_TYPE_IQ1_M = 29 +GGML_TYPE_COUNT = 30 # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index dace8d5..c24fca6 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -175,6 +175,9 @@ class LlamaProxy: chat_handler=chat_handler, # Speculative Decoding draft_model=draft_model, + # KV Cache Quantization + type_k=settings.type_k, + type_v=settings.type_v, # Tokenizer tokenizer=tokenizer, # Misc diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index daa913f..9ebdd0d 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -159,6 +159,15 @@ class ModelSettings(BaseSettings): default=10, description="Number of tokens to predict using the draft model.", ) + # KV Cache Quantization + type_k: Optional[int] = Field( + default=None, + description="Type of the key cache quantization.", + ) + type_v: Optional[int] = Field( + default=None, + description="Type of the value cache quantization.", + ) # Misc verbose: bool = Field( default=True, description="Whether to print debug information."