feat: add support for KV cache quantization options (#1307)
* Add KV cache quantization options
  https://github.com/abetlen/llama-cpp-python/discussions/1220
  https://github.com/abetlen/llama-cpp-python/issues/1305
* Add ggml_type
* Use ggml_type instead of string for quantization
* Add server support

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
parent aa9f1ae011
commit f165048a69

4 changed files with 94 additions and 41 deletions
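For context, a minimal usage sketch of the new options (the model path is a placeholder, q8_0 for both caches is just one plausible choice, and it assumes the ggml_type constants added further down in this diff are re-exported at package level like the other llama_cpp constants):

import llama_cpp
from llama_cpp import Llama

# Quantize both the K and V halves of the KV cache to q8_0 instead of the
# default f16. "./model.gguf" is a placeholder path.
llm = Llama(
    model_path="./model.gguf",
    type_k=llama_cpp.GGML_TYPE_Q8_0,  # ggml_type value 8
    type_v=llama_cpp.GGML_TYPE_Q8_0,
)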
@@ -105,6 +105,9 @@ class Llama:
        draft_model: Optional[LlamaDraftModel] = None,
        # Tokenizer Override
        tokenizer: Optional[BaseLlamaTokenizer] = None,
        # KV cache quantization
        type_k: Optional[int] = None,
        type_v: Optional[int] = None,
        # Misc
        verbose: bool = True,
        # Extra Params
@@ -172,6 +175,8 @@ class Llama:
            draft_model: Optional draft model to use for speculative decoding.
            tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
            verbose: Print verbose output to stderr.
            type_k: KV cache data type for K (default: f16)
            type_v: KV cache data type for V (default: f16)

        Raises:
            ValueError: If the model path does not exist.
@@ -298,7 +303,11 @@ class Llama:
        ) # Must be set to True for speculative decoding
        self.context_params.embeddings = embedding # TODO: Rename to embeddings
        self.context_params.offload_kqv = offload_kqv

        # KV cache quantization
        if type_k is not None:
            self.context_params.type_k = type_k
        if type_v is not None:
            self.context_params.type_v = type_v
        # Sampling Params
        self.last_n_tokens_size = last_n_tokens_size
@@ -1724,6 +1733,7 @@ class Llama:
            n_threads=self.context_params.n_threads,
            n_threads_batch=self.context_params.n_threads_batch,
            rope_scaling_type=self.context_params.rope_scaling_type,
            pooling_type=self.context_params.pooling_type,
            rope_freq_base=self.context_params.rope_freq_base,
            rope_freq_scale=self.context_params.rope_freq_scale,
            yarn_ext_factor=self.context_params.yarn_ext_factor,
@@ -1733,6 +1743,7 @@ class Llama:
            yarn_orig_ctx=self.context_params.yarn_orig_ctx,
            logits_all=self.context_params.logits_all,
            embedding=self.context_params.embeddings,
            offload_kqv=self.context_params.offload_kqv,
            # Sampling Params
            last_n_tokens_size=self.last_n_tokens_size,
            # LoRA Params
@@ -1744,51 +1755,17 @@ class Llama:
            # Chat Format Params
            chat_format=self.chat_format,
            chat_handler=self.chat_handler,
            # Speculative Decoding
            draft_model=self.draft_model,
            # KV cache quantization
            type_k=self.context_params.type_k,
            type_v=self.context_params.type_v,
            # Misc
            verbose=self.verbose,
        )

    def __setstate__(self, state):
        self.__init__(
            model_path=state["model_path"],
            # Model Params
            n_gpu_layers=state["n_gpu_layers"],
            split_mode=state["split_mode"],
            main_gpu=state["main_gpu"],
            tensor_split=state["tensor_split"],
            vocab_only=state["vocab_only"],
            use_mmap=state["use_mmap"],
            use_mlock=state["use_mlock"],
            kv_overrides=state["kv_overrides"],
            # Context Params
            seed=state["seed"],
            n_ctx=state["n_ctx"],
            n_batch=state["n_batch"],
            n_threads=state["n_threads"],
            n_threads_batch=state["n_threads_batch"],
            rope_freq_base=state["rope_freq_base"],
            rope_freq_scale=state["rope_freq_scale"],
            rope_scaling_type=state["rope_scaling_type"],
            yarn_ext_factor=state["yarn_ext_factor"],
            yarn_attn_factor=state["yarn_attn_factor"],
            yarn_beta_fast=state["yarn_beta_fast"],
            yarn_beta_slow=state["yarn_beta_slow"],
            yarn_orig_ctx=state["yarn_orig_ctx"],
            logits_all=state["logits_all"],
            embedding=state["embedding"],
            # Sampling Params
            last_n_tokens_size=state["last_n_tokens_size"],
            # LoRA Params
            lora_base=state["lora_base"],
            lora_path=state["lora_path"],
            # Backend Params
            numa=state["numa"],
            # Chat Format Params
            chat_format=state["chat_format"],
            chat_handler=state["chat_handler"],
            # Misc
            verbose=state["verbose"],
        )
        self.__init__(**state)

    def save_state(self) -> LlamaState:
        assert self._ctx.ctx is not None
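The __setstate__ simplification above means a pickled Llama now round-trips every keyword captured by __getstate__, including the new type_k/type_v. A sketch (it reloads the full model and assumes the GGUF file is still present at the pickled path):

import pickle

# `llm` is an existing Llama instance; __getstate__ captures its constructor
# kwargs and __setstate__ simply re-runs __init__(**state) on load.
blob = pickle.dumps(llm)
restored = pickle.loads(blob)
assert restored.context_params.type_k == llm.context_params.type_k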
@@ -141,6 +141,70 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa

byref = ctypes.byref  # type: ignore

# from ggml.h
# // NOTE: always add types at the end of the enum to keep backward compatibility
# enum ggml_type {
#     GGML_TYPE_F32 = 0,
#     GGML_TYPE_F16 = 1,
#     GGML_TYPE_Q4_0 = 2,
#     GGML_TYPE_Q4_1 = 3,
#     // GGML_TYPE_Q4_2 = 4, support has been removed
#     // GGML_TYPE_Q4_3 = 5, support has been removed
#     GGML_TYPE_Q5_0 = 6,
#     GGML_TYPE_Q5_1 = 7,
#     GGML_TYPE_Q8_0 = 8,
#     GGML_TYPE_Q8_1 = 9,
#     GGML_TYPE_Q2_K = 10,
#     GGML_TYPE_Q3_K = 11,
#     GGML_TYPE_Q4_K = 12,
#     GGML_TYPE_Q5_K = 13,
#     GGML_TYPE_Q6_K = 14,
#     GGML_TYPE_Q8_K = 15,
#     GGML_TYPE_IQ2_XXS = 16,
#     GGML_TYPE_IQ2_XS = 17,
#     GGML_TYPE_IQ3_XXS = 18,
#     GGML_TYPE_IQ1_S = 19,
#     GGML_TYPE_IQ4_NL = 20,
#     GGML_TYPE_IQ3_S = 21,
#     GGML_TYPE_IQ2_S = 22,
#     GGML_TYPE_IQ4_XS = 23,
#     GGML_TYPE_I8 = 24,
#     GGML_TYPE_I16 = 25,
#     GGML_TYPE_I32 = 26,
#     GGML_TYPE_I64 = 27,
#     GGML_TYPE_F64 = 28,
#     GGML_TYPE_IQ1_M = 29,
#     GGML_TYPE_COUNT,
# };
GGML_TYPE_F32 = 0
GGML_TYPE_F16 = 1
GGML_TYPE_Q4_0 = 2
GGML_TYPE_Q4_1 = 3
GGML_TYPE_Q5_0 = 6
GGML_TYPE_Q5_1 = 7
GGML_TYPE_Q8_0 = 8
GGML_TYPE_Q8_1 = 9
GGML_TYPE_Q2_K = 10
GGML_TYPE_Q3_K = 11
GGML_TYPE_Q4_K = 12
GGML_TYPE_Q5_K = 13
GGML_TYPE_Q6_K = 14
GGML_TYPE_Q8_K = 15
GGML_TYPE_IQ2_XXS = 16
GGML_TYPE_IQ2_XS = 17
GGML_TYPE_IQ3_XXS = 18
GGML_TYPE_IQ1_S = 19
GGML_TYPE_IQ4_NL = 20
GGML_TYPE_IQ3_S = 21
GGML_TYPE_IQ2_S = 22
GGML_TYPE_IQ4_XS = 23
GGML_TYPE_I8 = 24
GGML_TYPE_I16 = 25
GGML_TYPE_I32 = 26
GGML_TYPE_I64 = 27
GGML_TYPE_F64 = 28
GGML_TYPE_IQ1_M = 29
GGML_TYPE_COUNT = 30

# from ggml-backend.h
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
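These constants mirror ggml.h's ggml_type enum, so they can be passed straight through as type_k/type_v. A hypothetical convenience mapping (not part of this commit) for layers that still accept string names:

import llama_cpp

# Hypothetical helper: translate user-facing cache-type names into the
# ggml_type integers that type_k/type_v expect.
KV_CACHE_TYPE_BY_NAME = {
    "f32": llama_cpp.GGML_TYPE_F32,
    "f16": llama_cpp.GGML_TYPE_F16,  # llama.cpp's default KV cache type
    "q8_0": llama_cpp.GGML_TYPE_Q8_0,
    "q4_0": llama_cpp.GGML_TYPE_Q4_0,
}

print(KV_CACHE_TYPE_BY_NAME["q8_0"])  # -> 8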
@@ -175,6 +175,9 @@ class LlamaProxy:
            chat_handler=chat_handler,
            # Speculative Decoding
            draft_model=draft_model,
            # KV Cache Quantization
            type_k=settings.type_k,
            type_v=settings.type_v,
            # Tokenizer
            tokenizer=tokenizer,
            # Misc
@@ -159,6 +159,15 @@ class ModelSettings(BaseSettings):
        default=10,
        description="Number of tokens to predict using the draft model.",
    )
    # KV Cache Quantization
    type_k: Optional[int] = Field(
        default=None,
        description="Type of the key cache quantization.",
    )
    type_v: Optional[int] = Field(
        default=None,
        description="Type of the value cache quantization.",
    )
    # Misc
    verbose: bool = Field(
        default=True, description="Whether to print debug information."
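On the server side, the new fields behave like any other ModelSettings entry, so they should also be reachable through the CLI flags and environment variables derived from the field names. A sketch constructing the settings directly (the import path and the required `model` field are assumptions based on the surrounding code; 8 is GGML_TYPE_Q8_0):

from llama_cpp.server.settings import ModelSettings

# Quantize both KV cache halves to q8_0 for a server-hosted model.
settings = ModelSettings(
    model="./model.gguf",  # placeholder path
    type_k=8,  # GGML_TYPE_Q8_0
    type_v=8,
)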