Fix rope scaling defaults (#767)
* Fix rope scale with backwards compatibility
* Fix defaults
* Fix op
* Remove backwards compatibility
* Check single val
parent a72efc77de
commit a945404b4a

1 changed file with 8 additions and 7 deletions
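The substantive change: `rope_freq_base` and `rope_freq_scale` now default to `0.0` instead of `10000.0` and `1.0`. In llama.cpp a value of `0` acts as a sentinel for "use the RoPE settings stored in the model's GGUF metadata", which appears to be the motivation here: models trained with a non-standard base (CodeLlama uses `1e6`, for example) load correctly without manual overrides. A minimal usage sketch, assuming a hypothetical model path:

# With the new defaults, omitting the RoPE arguments lets llama.cpp read
# rope_freq_base / rope_freq_scale from the model file itself; a non-zero
# argument is still forwarded as an explicit override.
from llama_cpp import Llama

llm = Llama(model_path="./codellama-7b.Q4_K_M.gguf")  # RoPE params come from the model

llm_override = Llama(
    model_path="./codellama-7b.Q4_K_M.gguf",
    rope_freq_base=1000000.0,  # forwarded unchanged, overrides the model value
)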
@@ -229,8 +229,8 @@ class Llama:
         n_batch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
-        rope_freq_base: float = 10000.0,
-        rope_freq_scale: float = 1.0,
+        rope_freq_base: float = 0.0,
+        rope_freq_scale: float = 0.0,
         mul_mat_q: bool = True,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -282,7 +282,6 @@ class Llama:
         Returns:
             A Llama instance.
         """
-
         self.verbose = verbose

         self.numa = numa
@@ -320,7 +319,6 @@ class Llama:
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
-
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -328,8 +326,12 @@ class Llama:
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
         self.context_params.n_threads_batch = self.n_threads_batch
-        self.context_params.rope_freq_base = rope_freq_base
-        self.context_params.rope_freq_scale = rope_freq_scale
+        self.context_params.rope_freq_base = (
+            rope_freq_base if rope_freq_base != 0.0 else 0
+        )
+        self.context_params.rope_freq_scale = (
+            rope_freq_scale if rope_freq_scale != 0.0 else 0
+        )
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.f16_kv = f16_kv
         self.context_params.logits_all = logits_all
@@ -338,7 +340,6 @@ class Llama:
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size

-
         self.cache: Optional[BaseLlamaCache] = None

         self.lora_base = lora_base
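On the `else 0` branches in the rope hunk above: since `0.0` is already the sentinel, the conditionals forward any non-zero user override unchanged and normalise the default to `0`, leaving the actual resolution to llama.cpp. A sketch of that convention, under the assumption that llama.cpp substitutes the model's own value for a zero (function and names are illustrative, not library API):

def resolve_rope_param(user_value: float, model_value: float) -> float:
    # Mirrors the assumed llama.cpp behaviour: 0 means "defer to the model".
    return user_value if user_value != 0.0 else model_value

assert resolve_rope_param(0.0, 10000.0) == 10000.0      # default: model value wins
assert resolve_rope_param(26000.0, 10000.0) == 26000.0  # explicit override wins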