diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 24448ec..fbe3584 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -17,7 +17,7 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: "true" - name: Set up Python ${{ matrix.python-version }} diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index c9ea90f..705a4b2 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -230,8 +230,14 @@ class Llama: n_batch: int = 512, n_threads: Optional[int] = None, n_threads_batch: Optional[int] = None, + rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED, rope_freq_base: float = 0.0, rope_freq_scale: float = 0.0, + yarn_ext_factor: float = float("nan"), + yarn_attn_factor: float = 1.0, + yarn_beta_fast: float = 32.0, + yarn_beta_slow: float = 1.0, + yarn_orig_ctx: int = 0, mul_mat_q: bool = True, f16_kv: bool = True, logits_all: bool = False, @@ -255,30 +261,30 @@ class Llama: Args: model_path: Path to the model. - seed: Random seed. -1 for random. - n_ctx: Maximum context size. - n_batch: Maximum number of prompt tokens to batch together when calling llama_eval. n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded. - main_gpu: Main GPU to use. - tensor_split: Optional list of floats to split the model across multiple GPUs. If None, the model is not split. - rope_freq_base: Base frequency for rope sampling. - rope_freq_scale: Scale factor for rope sampling. - low_vram: Use low VRAM mode. - mul_mat_q: if true, use experimental mul_mat_q kernels - f16_kv: Use half-precision for key/value cache. - logits_all: Return logits for all tokens, not just the last token. + main_gpu: The GPU that is used for scratch and small tensors. + tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split. vocab_only: Only load the vocabulary no weights. use_mmap: Use mmap if possible. use_mlock: Force the system to keep the model in RAM. - embedding: Embedding mode only. + seed: Random seed. -1 for random. + n_ctx: Context size. + n_batch: Batch size for prompt processing (must be >= 32 to use BLAS) n_threads: Number of threads to use. If None, the number of threads is automatically determined. + n_threads_batch: Number of threads to use for batch processing. If None, use n_threads. + rope_scaling_type: Type of rope scaling to use. + rope_freq_base: Base frequency for rope sampling. + rope_freq_scale: Scale factor for rope sampling. + mul_mat_q: if true, use experimental mul_mat_q kernels + f16_kv: Use half-precision for key/value cache. + logits_all: Return logits for all tokens, not just the last token. + embedding: Embedding mode only. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. lora_path: Path to a LoRA file to apply to the model. numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init) chat_format: String specifying the chat format to use when calling create_chat_completion. verbose: Print verbose output to stderr. - kwargs: Unused keyword arguments (for additional backwards compatibility). Raises: ValueError: If the model path does not exist. @@ -332,12 +338,30 @@ class Llama: self.context_params.n_batch = self.n_batch self.context_params.n_threads = self.n_threads self.context_params.n_threads_batch = self.n_threads_batch + self.context_params.rope_scaling_type = ( + rope_scaling_type if rope_scaling_type is not None else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED + ) self.context_params.rope_freq_base = ( rope_freq_base if rope_freq_base != 0.0 else 0 ) self.context_params.rope_freq_scale = ( rope_freq_scale if rope_freq_scale != 0.0 else 0 ) + self.context_params.yarn_ext_factor = ( + yarn_ext_factor if yarn_ext_factor != 0.0 else 0 + ) + self.context_params.yarn_attn_factor = ( + yarn_attn_factor if yarn_attn_factor != 0.0 else 0 + ) + self.context_params.yarn_beta_fast = ( + yarn_beta_fast if yarn_beta_fast != 0.0 else 0 + ) + self.context_params.yarn_beta_slow = ( + yarn_beta_slow if yarn_beta_slow != 0.0 else 0 + ) + self.context_params.yarn_orig_ctx = ( + yarn_orig_ctx if yarn_orig_ctx != 0 else 0 + ) self.context_params.mul_mat_q = mul_mat_q self.context_params.f16_kv = f16_kv self.context_params.logits_all = logits_all @@ -1671,8 +1695,14 @@ class Llama: n_batch=self.n_batch, n_threads=self.context_params.n_threads, n_threads_batch=self.context_params.n_threads_batch, + rope_scaling_type=self.context_params.rope_scaling_type, rope_freq_base=self.context_params.rope_freq_base, rope_freq_scale=self.context_params.rope_freq_scale, + yarn_ext_factor=self.context_params.yarn_ext_factor, + yarn_attn_factor=self.context_params.yarn_attn_factor, + yarn_beta_fast=self.context_params.yarn_beta_fast, + yarn_beta_slow=self.context_params.yarn_beta_slow, + yarn_orig_ctx=self.context_params.yarn_orig_ctx, mul_mat_q=self.context_params.mul_mat_q, f16_kv=self.context_params.f16_kv, logits_all=self.context_params.logits_all, @@ -1709,6 +1739,12 @@ class Llama: n_threads_batch=state["n_threads_batch"], rope_freq_base=state["rope_freq_base"], rope_freq_scale=state["rope_freq_scale"], + rope_scaling_type=state["rope_scaling_type"], + yarn_ext_factor=state["yarn_ext_factor"], + yarn_attn_factor=state["yarn_attn_factor"], + yarn_beta_fast=state["yarn_beta_fast"], + yarn_beta_slow=state["yarn_beta_slow"], + yarn_orig_ctx=state["yarn_orig_ctx"], mul_mat_q=state["mul_mat_q"], f16_kv=state["f16_kv"], logits_all=state["logits_all"], diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index ba4e26b..b6216a5 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -192,6 +192,18 @@ LLAMA_FTYPE_MOSTLY_Q5_K_M = 17 LLAMA_FTYPE_MOSTLY_Q6_K = 18 LLAMA_FTYPE_GUESSED = 1024 +# enum llama_rope_scaling_type { +# LLAMA_ROPE_SCALING_UNSPECIFIED = -1, +# LLAMA_ROPE_SCALING_NONE = 0, +# LLAMA_ROPE_SCALING_LINEAR = 1, +# LLAMA_ROPE_SCALING_YARN = 2, +# LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN, +# }; +LLAMA_ROPE_SCALING_UNSPECIFIED = -1 +LLAMA_ROPE_SCALING_NONE = 0 +LLAMA_ROPE_SCALING_LINEAR = 1 +LLAMA_ROPE_SCALING_YARN = 2 +LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN # typedef struct llama_token_data { # llama_token id; // token id @@ -308,10 +320,16 @@ class llama_model_params(Structure): # uint32_t n_batch; // prompt processing maximum batch size # uint32_t n_threads; // number of threads to use for generation # uint32_t n_threads_batch; // number of threads to use for batch processing +# int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` # // ref: https://github.com/ggerganov/llama.cpp/pull/2054 -# float rope_freq_base; // RoPE base frequency, 0 = from model -# float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model +# float rope_freq_base; // RoPE base frequency, 0 = from model +# float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model +# float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model +# float yarn_attn_factor; // YaRN magnitude scaling factor +# float yarn_beta_fast; // YaRN low correction dim +# float yarn_beta_slow; // YaRN high correction dim +# uint32_t yarn_orig_ctx; // YaRN original context size # // Keep the booleans together to avoid misalignment during copy-by-value. @@ -327,8 +345,14 @@ class llama_context_params(Structure): ("n_batch", c_uint32), ("n_threads", c_uint32), ("n_threads_batch", c_uint32), + ("rope_scaling_type", c_int8), ("rope_freq_base", c_float), ("rope_freq_scale", c_float), + ("yarn_ext_factor", c_float), + ("yarn_attn_factor", c_float), + ("yarn_beta_fast", c_float), + ("yarn_beta_slow", c_float), + ("yarn_orig_ctx", c_uint32), ("mul_mat_q", c_bool), ("f16_kv", c_bool), ("logits_all", c_bool), diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index f8d8c76..73b660a 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -41,11 +41,7 @@ class Settings(BaseSettings): default=None, description="The alias of the model to use for generating completions.", ) - seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.") - n_ctx: int = Field(default=2048, ge=1, description="The context size.") - n_batch: int = Field( - default=512, ge=1, description="The batch size to use per eval." - ) + # Model Params n_gpu_layers: int = Field( default=0, ge=-1, @@ -60,17 +56,6 @@ class Settings(BaseSettings): default=None, description="Split layers across multiple GPUs in proportion.", ) - rope_freq_base: float = Field( - default=0.0, description="RoPE base frequency" - ) - rope_freq_scale: float = Field( - default=0.0, description="RoPE frequency scaling factor" - ) - mul_mat_q: bool = Field( - default=True, description="if true, use experimental mul_mat_q kernels" - ) - f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.") - logits_all: bool = Field(default=True, description="Whether to return logits.") vocab_only: bool = Field( default=False, description="Whether to only return the vocabulary." ) @@ -82,17 +67,59 @@ class Settings(BaseSettings): default=llama_cpp.llama_mlock_supported(), description="Use mlock.", ) - embedding: bool = Field(default=True, description="Whether to use embeddings.") + # Context Params + seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.") + n_ctx: int = Field(default=2048, ge=1, description="The context size.") + n_batch: int = Field( + default=512, ge=1, description="The batch size to use per eval." + ) n_threads: int = Field( default=max(multiprocessing.cpu_count() // 2, 1), ge=1, description="The number of threads to use.", ) + n_threads_batch: int = Field( + default=max(multiprocessing.cpu_count() // 2, 1), + ge=0, + description="The number of threads to use when batch processing.", + ) + rope_scaling_type: int = Field( + default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED + ) + rope_freq_base: float = Field( + default=0.0, description="RoPE base frequency" + ) + rope_freq_scale: float = Field( + default=0.0, description="RoPE frequency scaling factor" + ) + yarn_ext_factor: float = Field( + default=float("nan") + ) + yarn_attn_factor: float = Field( + default=1.0 + ) + yarn_beta_fast: float = Field( + default=32.0 + ) + yarn_beta_slow: float = Field( + default=1.0 + ) + yarn_orig_ctx: int = Field( + default=0 + ) + mul_mat_q: bool = Field( + default=True, description="if true, use experimental mul_mat_q kernels" + ) + f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.") + logits_all: bool = Field(default=True, description="Whether to return logits.") + embedding: bool = Field(default=True, description="Whether to use embeddings.") + # Sampling Params last_n_tokens_size: int = Field( default=64, ge=0, description="Last n tokens to keep for repeat penalty calculation.", ) + # LoRA Params lora_base: Optional[str] = Field( default=None, description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model." @@ -101,14 +128,17 @@ class Settings(BaseSettings): default=None, description="Path to a LoRA file to apply to the model.", ) + # Backend Params numa: bool = Field( default=False, description="Enable NUMA support.", ) + # Chat Format Params chat_format: str = Field( default="llama-2", description="Chat format to use.", ) + # Cache Params cache: bool = Field( default=False, description="Use a cache to reduce processing times for evaluated prompts.", @@ -121,9 +151,11 @@ class Settings(BaseSettings): default=2 << 30, description="The size of the cache in bytes. Only used if cache is True.", ) + # Misc verbose: bool = Field( default=True, description="Whether to print debug information." ) + # Server Params host: str = Field(default="localhost", description="Listen address") port: int = Field(default=8000, description="Listen port") interrupt_requests: bool = Field( @@ -345,27 +377,41 @@ def create_app(settings: Optional[Settings] = None): global llama llama = llama_cpp.Llama( model_path=settings.model, - seed=settings.seed, - n_ctx=settings.n_ctx, - n_batch=settings.n_batch, + # Model Params n_gpu_layers=settings.n_gpu_layers, main_gpu=settings.main_gpu, tensor_split=settings.tensor_split, - rope_freq_base=settings.rope_freq_base, - rope_freq_scale=settings.rope_freq_scale, - mul_mat_q=settings.mul_mat_q, - f16_kv=settings.f16_kv, - logits_all=settings.logits_all, vocab_only=settings.vocab_only, use_mmap=settings.use_mmap, use_mlock=settings.use_mlock, - embedding=settings.embedding, + # Context Params + seed=settings.seed, + n_ctx=settings.n_ctx, + n_batch=settings.n_batch, n_threads=settings.n_threads, + n_threads_batch=settings.n_threads_batch, + rope_scaling_type=settings.rope_scaling_type, + rope_freq_base=settings.rope_freq_base, + rope_freq_scale=settings.rope_freq_scale, + yarn_ext_factor=settings.yarn_ext_factor, + yarn_attn_factor=settings.yarn_attn_factor, + yarn_beta_fast=settings.yarn_beta_fast, + yarn_beta_slow=settings.yarn_beta_slow, + yarn_orig_ctx=settings.yarn_orig_ctx, + mul_mat_q=settings.mul_mat_q, + f16_kv=settings.f16_kv, + logits_all=settings.logits_all, + embedding=settings.embedding, + # Sampling Params last_n_tokens_size=settings.last_n_tokens_size, + # LoRA Params lora_base=settings.lora_base, lora_path=settings.lora_path, + # Backend Params numa=settings.numa, + # Chat Format Params chat_format=settings.chat_format, + # Misc verbose=settings.verbose, ) if settings.cache: diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 5033796..4ff1046 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 50337961a678fce4081554b24e56e86b67660163 +Subproject commit 4ff1046d75e64f0e556d8dcd930ea25c23eb8b18