Add bindings for custom_rope

This commit is contained in:
randoentity 2023-07-09 09:05:16 +02:00
parent 8e0f6253db
commit 3f8f276f9f
3 changed files with 9 additions and 1 deletions

View file

@ -205,6 +205,8 @@ class Llama:
model_path: str, model_path: str,
# NOTE: These parameters are likely to change in the future. # NOTE: These parameters are likely to change in the future.
n_ctx: int = 512, n_ctx: int = 512,
rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0,
n_parts: int = -1, n_parts: int = -1,
n_gpu_layers: int = 0, n_gpu_layers: int = 0,
seed: int = 1337, seed: int = 1337,
@ -227,6 +229,8 @@ class Llama:
Args: Args:
model_path: Path to the model. model_path: Path to the model.
n_ctx: Maximum context size. n_ctx: Maximum context size.
rope_freq_base: RoPE base frequency.
rope_freq_scale: RoPE frequency scale.
n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined. n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
seed: Random seed. -1 for random. seed: Random seed. -1 for random.
f16_kv: Use half-precision for key/value cache. f16_kv: Use half-precision for key/value cache.
@ -253,6 +257,8 @@ class Llama:
self.params = llama_cpp.llama_context_default_params() self.params = llama_cpp.llama_context_default_params()
self.params.n_ctx = n_ctx self.params.n_ctx = n_ctx
self.params.rope_freq_base = rope_freq_base
self.params.rope_freq_scale = rope_freq_scale
self.params.n_gpu_layers = n_gpu_layers self.params.n_gpu_layers = n_gpu_layers
self.params.seed = seed self.params.seed = seed
self.params.f16_kv = f16_kv self.params.f16_kv = f16_kv

View file

@ -184,6 +184,8 @@ class llama_context_params(Structure):
_fields_ = [ _fields_ = [
("seed", c_uint32), ("seed", c_uint32),
("n_ctx", c_int32), ("n_ctx", c_int32),
("rope_freq_base", c_float),
("rope_freq_scale", c_float),
("n_batch", c_int32), ("n_batch", c_int32),
("n_gpu_layers", c_int32), ("n_gpu_layers", c_int32),
("main_gpu", c_int32), ("main_gpu", c_int32),

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 1d1630996920f889cdc08de26cebf2415958540e Subproject commit a3b4d932859f4e51ed716bfa1f07e2d2eede2c23