Bugfix: Remove f16_kv, add offload_kqv field (#1019)

F16_KV appears to have been removed here: af99c6fbfc

This addresses two issues:

 - #995 which just requests to add the KV cache offloading param
 - #1006 a NULL ptr exception when using the embeddings (introduced by
   leaving f16_kv in the fields struct)
This commit is contained in:
Brandon Roberts 2023-12-18 12:27:11 -07:00 committed by GitHub
parent 37da8e863a
commit 62944df142
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 10 deletions

View file

@ -751,7 +751,6 @@ class Llama:
yarn_beta_slow: float = 1.0,
yarn_orig_ctx: int = 0,
mul_mat_q: bool = True,
f16_kv: bool = True,
logits_all: bool = False,
embedding: bool = False,
# Sampling Params
@ -817,7 +816,6 @@ class Llama:
yarn_beta_fast: YaRN low correction dim
yarn_beta_slow: YaRN high correction dim
yarn_orig_ctx: YaRN original context size
f16_kv: Use fp16 for KV cache, fp32 otherwise
logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
embedding: Embedding mode only.
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
@ -904,7 +902,6 @@ class Llama:
)
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.mul_mat_q = mul_mat_q
# self.context_params.f16_kv = f16_kv
self.context_params.logits_all = logits_all
self.context_params.embedding = embedding
@ -2155,7 +2152,6 @@ class Llama:
yarn_beta_slow=self.context_params.yarn_beta_slow,
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
mul_mat_q=self.context_params.mul_mat_q,
f16_kv=self.context_params.f16_kv,
logits_all=self.context_params.logits_all,
embedding=self.context_params.embedding,
# Sampling Params
@ -2198,7 +2194,6 @@ class Llama:
yarn_beta_slow=state["yarn_beta_slow"],
yarn_orig_ctx=state["yarn_orig_ctx"],
mul_mat_q=state["mul_mat_q"],
f16_kv=state["f16_kv"],
logits_all=state["logits_all"],
embedding=state["embedding"],
# Sampling Params

View file

@ -432,9 +432,9 @@ class llama_context_params(Structure):
type_k (int): data type for K cache
type_v (int): data type for V cache
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
f16_kv (bool): use fp16 for KV cache, fp32 otherwise
logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
embedding (bool): embedding mode only"""
embedding (bool): embedding mode only
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU"""
_fields_ = [
("seed", c_uint32),
("n_ctx", c_uint32),
@ -452,9 +452,9 @@ class llama_context_params(Structure):
("type_k", c_int),
("type_v", c_int),
("mul_mat_q", c_bool),
("f16_kv", c_bool),
("logits_all", c_bool),
("embedding", c_bool),
("offload_kqv", c_bool),
]

View file

@ -98,7 +98,6 @@ class Settings(BaseSettings):
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
logits_all: bool = Field(default=True, description="Whether to return logits.")
embedding: bool = Field(default=True, description="Whether to use embeddings.")
# Sampling Params
@ -408,7 +407,6 @@ def create_app(settings: Optional[Settings] = None):
yarn_beta_slow=settings.yarn_beta_slow,
yarn_orig_ctx=settings.yarn_orig_ctx,
mul_mat_q=settings.mul_mat_q,
f16_kv=settings.f16_kv,
logits_all=settings.logits_all,
embedding=settings.embedding,
# Sampling Params