Update llama.cpp
Fix build examples Exclude examples directory Revert cmake changes Try actions/checkout@v4 Try to update submodules Revert Update llama.cpp Fix build examples Exclude examples directory Revert cmake changes Try actions/checkout@v4 Try to update submodules Revert
This commit is contained in:
parent
ddbd10c442
commit
fa83cc5f9c
5 changed files with 149 additions and 43 deletions
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
|
@ -17,7 +17,7 @@ jobs:
|
|||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: "true"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
|
|
@ -230,8 +230,14 @@ class Llama:
|
|||
n_batch: int = 512,
|
||||
n_threads: Optional[int] = None,
|
||||
n_threads_batch: Optional[int] = None,
|
||||
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED,
|
||||
rope_freq_base: float = 0.0,
|
||||
rope_freq_scale: float = 0.0,
|
||||
yarn_ext_factor: float = float("nan"),
|
||||
yarn_attn_factor: float = 1.0,
|
||||
yarn_beta_fast: float = 32.0,
|
||||
yarn_beta_slow: float = 1.0,
|
||||
yarn_orig_ctx: int = 0,
|
||||
mul_mat_q: bool = True,
|
||||
f16_kv: bool = True,
|
||||
logits_all: bool = False,
|
||||
|
@ -255,30 +261,30 @@ class Llama:
|
|||
|
||||
Args:
|
||||
model_path: Path to the model.
|
||||
seed: Random seed. -1 for random.
|
||||
n_ctx: Maximum context size.
|
||||
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
|
||||
n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
|
||||
main_gpu: Main GPU to use.
|
||||
tensor_split: Optional list of floats to split the model across multiple GPUs. If None, the model is not split.
|
||||
rope_freq_base: Base frequency for rope sampling.
|
||||
rope_freq_scale: Scale factor for rope sampling.
|
||||
low_vram: Use low VRAM mode.
|
||||
mul_mat_q: if true, use experimental mul_mat_q kernels
|
||||
f16_kv: Use half-precision for key/value cache.
|
||||
logits_all: Return logits for all tokens, not just the last token.
|
||||
main_gpu: The GPU that is used for scratch and small tensors.
|
||||
tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
|
||||
vocab_only: Only load the vocabulary no weights.
|
||||
use_mmap: Use mmap if possible.
|
||||
use_mlock: Force the system to keep the model in RAM.
|
||||
embedding: Embedding mode only.
|
||||
seed: Random seed. -1 for random.
|
||||
n_ctx: Context size.
|
||||
n_batch: Batch size for prompt processing (must be >= 32 to use BLAS)
|
||||
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
|
||||
n_threads_batch: Number of threads to use for batch processing. If None, use n_threads.
|
||||
rope_scaling_type: Type of rope scaling to use.
|
||||
rope_freq_base: Base frequency for rope sampling.
|
||||
rope_freq_scale: Scale factor for rope sampling.
|
||||
mul_mat_q: if true, use experimental mul_mat_q kernels
|
||||
f16_kv: Use half-precision for key/value cache.
|
||||
logits_all: Return logits for all tokens, not just the last token.
|
||||
embedding: Embedding mode only.
|
||||
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
|
||||
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
|
||||
lora_path: Path to a LoRA file to apply to the model.
|
||||
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
|
||||
chat_format: String specifying the chat format to use when calling create_chat_completion.
|
||||
verbose: Print verbose output to stderr.
|
||||
kwargs: Unused keyword arguments (for additional backwards compatibility).
|
||||
|
||||
Raises:
|
||||
ValueError: If the model path does not exist.
|
||||
|
@ -332,12 +338,30 @@ class Llama:
|
|||
self.context_params.n_batch = self.n_batch
|
||||
self.context_params.n_threads = self.n_threads
|
||||
self.context_params.n_threads_batch = self.n_threads_batch
|
||||
self.context_params.rope_scaling_type = (
|
||||
rope_scaling_type if rope_scaling_type is not None else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
|
||||
)
|
||||
self.context_params.rope_freq_base = (
|
||||
rope_freq_base if rope_freq_base != 0.0 else 0
|
||||
)
|
||||
self.context_params.rope_freq_scale = (
|
||||
rope_freq_scale if rope_freq_scale != 0.0 else 0
|
||||
)
|
||||
self.context_params.yarn_ext_factor = (
|
||||
yarn_ext_factor if yarn_ext_factor != 0.0 else 0
|
||||
)
|
||||
self.context_params.yarn_attn_factor = (
|
||||
yarn_attn_factor if yarn_attn_factor != 0.0 else 0
|
||||
)
|
||||
self.context_params.yarn_beta_fast = (
|
||||
yarn_beta_fast if yarn_beta_fast != 0.0 else 0
|
||||
)
|
||||
self.context_params.yarn_beta_slow = (
|
||||
yarn_beta_slow if yarn_beta_slow != 0.0 else 0
|
||||
)
|
||||
self.context_params.yarn_orig_ctx = (
|
||||
yarn_orig_ctx if yarn_orig_ctx != 0 else 0
|
||||
)
|
||||
self.context_params.mul_mat_q = mul_mat_q
|
||||
self.context_params.f16_kv = f16_kv
|
||||
self.context_params.logits_all = logits_all
|
||||
|
@ -1671,8 +1695,14 @@ class Llama:
|
|||
n_batch=self.n_batch,
|
||||
n_threads=self.context_params.n_threads,
|
||||
n_threads_batch=self.context_params.n_threads_batch,
|
||||
rope_scaling_type=self.context_params.rope_scaling_type,
|
||||
rope_freq_base=self.context_params.rope_freq_base,
|
||||
rope_freq_scale=self.context_params.rope_freq_scale,
|
||||
yarn_ext_factor=self.context_params.yarn_ext_factor,
|
||||
yarn_attn_factor=self.context_params.yarn_attn_factor,
|
||||
yarn_beta_fast=self.context_params.yarn_beta_fast,
|
||||
yarn_beta_slow=self.context_params.yarn_beta_slow,
|
||||
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
|
||||
mul_mat_q=self.context_params.mul_mat_q,
|
||||
f16_kv=self.context_params.f16_kv,
|
||||
logits_all=self.context_params.logits_all,
|
||||
|
@ -1709,6 +1739,12 @@ class Llama:
|
|||
n_threads_batch=state["n_threads_batch"],
|
||||
rope_freq_base=state["rope_freq_base"],
|
||||
rope_freq_scale=state["rope_freq_scale"],
|
||||
rope_scaling_type=state["rope_scaling_type"],
|
||||
yarn_ext_factor=state["yarn_ext_factor"],
|
||||
yarn_attn_factor=state["yarn_attn_factor"],
|
||||
yarn_beta_fast=state["yarn_beta_fast"],
|
||||
yarn_beta_slow=state["yarn_beta_slow"],
|
||||
yarn_orig_ctx=state["yarn_orig_ctx"],
|
||||
mul_mat_q=state["mul_mat_q"],
|
||||
f16_kv=state["f16_kv"],
|
||||
logits_all=state["logits_all"],
|
||||
|
|
|
@ -192,6 +192,18 @@ LLAMA_FTYPE_MOSTLY_Q5_K_M = 17
|
|||
LLAMA_FTYPE_MOSTLY_Q6_K = 18
|
||||
LLAMA_FTYPE_GUESSED = 1024
|
||||
|
||||
# enum llama_rope_scaling_type {
|
||||
# LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
|
||||
# LLAMA_ROPE_SCALING_NONE = 0,
|
||||
# LLAMA_ROPE_SCALING_LINEAR = 1,
|
||||
# LLAMA_ROPE_SCALING_YARN = 2,
|
||||
# LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
||||
# };
|
||||
LLAMA_ROPE_SCALING_UNSPECIFIED = -1
|
||||
LLAMA_ROPE_SCALING_NONE = 0
|
||||
LLAMA_ROPE_SCALING_LINEAR = 1
|
||||
LLAMA_ROPE_SCALING_YARN = 2
|
||||
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN
|
||||
|
||||
# typedef struct llama_token_data {
|
||||
# llama_token id; // token id
|
||||
|
@ -308,10 +320,16 @@ class llama_model_params(Structure):
|
|||
# uint32_t n_batch; // prompt processing maximum batch size
|
||||
# uint32_t n_threads; // number of threads to use for generation
|
||||
# uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
# int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
||||
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
# float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
# float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||
# float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
|
||||
# float yarn_attn_factor; // YaRN magnitude scaling factor
|
||||
# float yarn_beta_fast; // YaRN low correction dim
|
||||
# float yarn_beta_slow; // YaRN high correction dim
|
||||
# uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
|
||||
|
||||
# // Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
|
@ -327,8 +345,14 @@ class llama_context_params(Structure):
|
|||
("n_batch", c_uint32),
|
||||
("n_threads", c_uint32),
|
||||
("n_threads_batch", c_uint32),
|
||||
("rope_scaling_type", c_int8),
|
||||
("rope_freq_base", c_float),
|
||||
("rope_freq_scale", c_float),
|
||||
("yarn_ext_factor", c_float),
|
||||
("yarn_attn_factor", c_float),
|
||||
("yarn_beta_fast", c_float),
|
||||
("yarn_beta_slow", c_float),
|
||||
("yarn_orig_ctx", c_uint32),
|
||||
("mul_mat_q", c_bool),
|
||||
("f16_kv", c_bool),
|
||||
("logits_all", c_bool),
|
||||
|
|
|
@ -41,11 +41,7 @@ class Settings(BaseSettings):
|
|||
default=None,
|
||||
description="The alias of the model to use for generating completions.",
|
||||
)
|
||||
seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.")
|
||||
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
|
||||
n_batch: int = Field(
|
||||
default=512, ge=1, description="The batch size to use per eval."
|
||||
)
|
||||
# Model Params
|
||||
n_gpu_layers: int = Field(
|
||||
default=0,
|
||||
ge=-1,
|
||||
|
@ -60,17 +56,6 @@ class Settings(BaseSettings):
|
|||
default=None,
|
||||
description="Split layers across multiple GPUs in proportion.",
|
||||
)
|
||||
rope_freq_base: float = Field(
|
||||
default=0.0, description="RoPE base frequency"
|
||||
)
|
||||
rope_freq_scale: float = Field(
|
||||
default=0.0, description="RoPE frequency scaling factor"
|
||||
)
|
||||
mul_mat_q: bool = Field(
|
||||
default=True, description="if true, use experimental mul_mat_q kernels"
|
||||
)
|
||||
f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
|
||||
logits_all: bool = Field(default=True, description="Whether to return logits.")
|
||||
vocab_only: bool = Field(
|
||||
default=False, description="Whether to only return the vocabulary."
|
||||
)
|
||||
|
@ -82,17 +67,59 @@ class Settings(BaseSettings):
|
|||
default=llama_cpp.llama_mlock_supported(),
|
||||
description="Use mlock.",
|
||||
)
|
||||
embedding: bool = Field(default=True, description="Whether to use embeddings.")
|
||||
# Context Params
|
||||
seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.")
|
||||
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
|
||||
n_batch: int = Field(
|
||||
default=512, ge=1, description="The batch size to use per eval."
|
||||
)
|
||||
n_threads: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=1,
|
||||
description="The number of threads to use.",
|
||||
)
|
||||
n_threads_batch: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=0,
|
||||
description="The number of threads to use when batch processing.",
|
||||
)
|
||||
rope_scaling_type: int = Field(
|
||||
default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
|
||||
)
|
||||
rope_freq_base: float = Field(
|
||||
default=0.0, description="RoPE base frequency"
|
||||
)
|
||||
rope_freq_scale: float = Field(
|
||||
default=0.0, description="RoPE frequency scaling factor"
|
||||
)
|
||||
yarn_ext_factor: float = Field(
|
||||
default=float("nan")
|
||||
)
|
||||
yarn_attn_factor: float = Field(
|
||||
default=1.0
|
||||
)
|
||||
yarn_beta_fast: float = Field(
|
||||
default=32.0
|
||||
)
|
||||
yarn_beta_slow: float = Field(
|
||||
default=1.0
|
||||
)
|
||||
yarn_orig_ctx: int = Field(
|
||||
default=0
|
||||
)
|
||||
mul_mat_q: bool = Field(
|
||||
default=True, description="if true, use experimental mul_mat_q kernels"
|
||||
)
|
||||
f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
|
||||
logits_all: bool = Field(default=True, description="Whether to return logits.")
|
||||
embedding: bool = Field(default=True, description="Whether to use embeddings.")
|
||||
# Sampling Params
|
||||
last_n_tokens_size: int = Field(
|
||||
default=64,
|
||||
ge=0,
|
||||
description="Last n tokens to keep for repeat penalty calculation.",
|
||||
)
|
||||
# LoRA Params
|
||||
lora_base: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model."
|
||||
|
@ -101,14 +128,17 @@ class Settings(BaseSettings):
|
|||
default=None,
|
||||
description="Path to a LoRA file to apply to the model.",
|
||||
)
|
||||
# Backend Params
|
||||
numa: bool = Field(
|
||||
default=False,
|
||||
description="Enable NUMA support.",
|
||||
)
|
||||
# Chat Format Params
|
||||
chat_format: str = Field(
|
||||
default="llama-2",
|
||||
description="Chat format to use.",
|
||||
)
|
||||
# Cache Params
|
||||
cache: bool = Field(
|
||||
default=False,
|
||||
description="Use a cache to reduce processing times for evaluated prompts.",
|
||||
|
@ -121,9 +151,11 @@ class Settings(BaseSettings):
|
|||
default=2 << 30,
|
||||
description="The size of the cache in bytes. Only used if cache is True.",
|
||||
)
|
||||
# Misc
|
||||
verbose: bool = Field(
|
||||
default=True, description="Whether to print debug information."
|
||||
)
|
||||
# Server Params
|
||||
host: str = Field(default="localhost", description="Listen address")
|
||||
port: int = Field(default=8000, description="Listen port")
|
||||
interrupt_requests: bool = Field(
|
||||
|
@ -345,27 +377,41 @@ def create_app(settings: Optional[Settings] = None):
|
|||
global llama
|
||||
llama = llama_cpp.Llama(
|
||||
model_path=settings.model,
|
||||
seed=settings.seed,
|
||||
n_ctx=settings.n_ctx,
|
||||
n_batch=settings.n_batch,
|
||||
# Model Params
|
||||
n_gpu_layers=settings.n_gpu_layers,
|
||||
main_gpu=settings.main_gpu,
|
||||
tensor_split=settings.tensor_split,
|
||||
rope_freq_base=settings.rope_freq_base,
|
||||
rope_freq_scale=settings.rope_freq_scale,
|
||||
mul_mat_q=settings.mul_mat_q,
|
||||
f16_kv=settings.f16_kv,
|
||||
logits_all=settings.logits_all,
|
||||
vocab_only=settings.vocab_only,
|
||||
use_mmap=settings.use_mmap,
|
||||
use_mlock=settings.use_mlock,
|
||||
embedding=settings.embedding,
|
||||
# Context Params
|
||||
seed=settings.seed,
|
||||
n_ctx=settings.n_ctx,
|
||||
n_batch=settings.n_batch,
|
||||
n_threads=settings.n_threads,
|
||||
n_threads_batch=settings.n_threads_batch,
|
||||
rope_scaling_type=settings.rope_scaling_type,
|
||||
rope_freq_base=settings.rope_freq_base,
|
||||
rope_freq_scale=settings.rope_freq_scale,
|
||||
yarn_ext_factor=settings.yarn_ext_factor,
|
||||
yarn_attn_factor=settings.yarn_attn_factor,
|
||||
yarn_beta_fast=settings.yarn_beta_fast,
|
||||
yarn_beta_slow=settings.yarn_beta_slow,
|
||||
yarn_orig_ctx=settings.yarn_orig_ctx,
|
||||
mul_mat_q=settings.mul_mat_q,
|
||||
f16_kv=settings.f16_kv,
|
||||
logits_all=settings.logits_all,
|
||||
embedding=settings.embedding,
|
||||
# Sampling Params
|
||||
last_n_tokens_size=settings.last_n_tokens_size,
|
||||
# LoRA Params
|
||||
lora_base=settings.lora_base,
|
||||
lora_path=settings.lora_path,
|
||||
# Backend Params
|
||||
numa=settings.numa,
|
||||
# Chat Format Params
|
||||
chat_format=settings.chat_format,
|
||||
# Misc
|
||||
verbose=settings.verbose,
|
||||
)
|
||||
if settings.cache:
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 50337961a678fce4081554b24e56e86b67660163
|
||||
Subproject commit 4ff1046d75e64f0e556d8dcd930ea25c23eb8b18
|
Loading…
Reference in a new issue