diff --git a/CHANGELOG.md b/CHANGELOG.md index 5061247..c0748ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.29] + +- feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b +- feat: Add split_mode option by @abetlen in 84615adbc6855c8384807c42f0130f9a1763f99d +- feat: Implement GGUF metadata KV overrides by @phiharri in #1011 +- fix: Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor by @yieldthought in #1012 +- fix: Fix low_level_api_chat_cpp example to match current API by @aniljava in #1086 +- fix: Fix Pydantic model parsing by @DeNeutoy in #1087 + ## [0.2.28] - feat: Update llama.cpp to ggerganov/llama.cpp@6efb8eb30e7025b168f3fda3ff83b9b386428ad6 diff --git a/README.md b/README.md index b2e879e..ad5d0f1 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ This package provides: - High-level Python API for text completion - OpenAI-like API - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp) + - [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html) - OpenAI compatible web server - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion) - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling) diff --git a/examples/low_level_api/common.py b/examples/low_level_api/common.py index 55d08db..1a51525 100644 --- a/examples/low_level_api/common.py +++ b/examples/low_level_api/common.py @@ -106,7 +106,7 @@ def gpt_params_parse(argv = None): parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta") parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model") - parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt") + parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt") parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file") parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session") parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix") diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 44b6d4a..02c09af 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -62,7 +62,7 @@ specified) expect poor results""", file=sys.stderr) self.multibyte_fix = [] # model load - self.lparams = llama_cpp.llama_context_default_params() + self.lparams = llama_cpp.llama_model_default_params() self.lparams.n_ctx = self.params.n_ctx self.lparams.n_parts = self.params.n_parts self.lparams.seed = self.params.seed @@ -72,7 +72,11 @@ specified) expect poor results""", file=sys.stderr) self.model = llama_cpp.llama_load_model_from_file( self.params.model.encode("utf8"), self.lparams) - self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams) + + # Context Params. + self.cparams = llama_cpp.llama_context_default_params() + + self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams) if (not self.ctx): raise RuntimeError(f"error: failed to load model '{self.params.model}'") @@ -244,7 +248,7 @@ n_keep = {self.params.n_keep} # tokenize a prompt def _tokenize(self, prompt, bos=True): _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))() - _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos) + _n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False) return _arr[:_n] def set_color(self, c): @@ -304,7 +308,7 @@ n_keep = {self.params.n_keep} self.n_past += n_eval""" if (llama_cpp.llama_eval( - self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads + self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past ) != 0): raise Exception("Failed to llama_eval!") @@ -332,7 +336,7 @@ n_keep = {self.params.n_keep} id = 0 logits = llama_cpp.llama_get_logits(self.ctx) - n_vocab = llama_cpp.llama_n_vocab(self.ctx) + n_vocab = llama_cpp.llama_n_vocab(self.model) # Apply params.logit_bias map for key, value in self.params.logit_bias.items(): @@ -349,12 +353,20 @@ n_keep = {self.params.n_keep} last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx) _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:]) - llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p, - _arr, - last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty)) - llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p, - _arr, - last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty)) + llama_cpp.llama_sample_repetition_penalties( + ctx=self.ctx, + candidates=candidates_p, + last_tokens_data = _arr, + penalty_last_n = last_n_repeat, + penalty_repeat = llama_cpp.c_float(self.params.repeat_penalty), + penalty_freq = llama_cpp.c_float(self.params.frequency_penalty), + penalty_present = llama_cpp.c_float(self.params.presence_penalty), + ) + + # NOT PRESENT IN CURRENT VERSION ? + # llama_cpp.llama_sample_frequency_and_presence_penalti(self.ctx, candidates_p, + # _arr, + # last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty)) if not self.params.penalize_nl: logits[llama_cpp.llama_token_nl()] = nl_logit @@ -473,7 +485,7 @@ n_keep = {self.params.n_keep} def token_to_str(self, token_id: int) -> bytes: size = 32 buffer = (ctypes.c_char * size)() - n = llama_cpp.llama_token_to_piece_with_model( + n = llama_cpp.llama_token_to_piece( self.model, llama_cpp.llama_token(token_id), buffer, size) assert n <= size return bytes(buffer[:n]) @@ -532,6 +544,9 @@ n_keep = {self.params.n_keep} print(i,end="",flush=True) self.params.input_echo = False + # Using string instead of tokens to check for antiprompt, + # It is more reliable than tokens for interactive mode. + generated_str = "" while self.params.interactive: self.set_color(util.CONSOLE_COLOR_USER_INPUT) if (self.params.instruct): @@ -546,6 +561,10 @@ n_keep = {self.params.n_keep} try: for i in self.output(): print(i,end="",flush=True) + generated_str += i + for ap in self.params.antiprompt: + if generated_str.endswith(ap): + raise KeyboardInterrupt except KeyboardInterrupt: self.set_color(util.CONSOLE_COLOR_DEFAULT) if not self.params.instruct: @@ -561,7 +580,7 @@ if __name__ == "__main__": time_now = datetime.now() prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}. {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision. -There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other. +Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other. The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long. The transcript only includes text, it does not include markup like HTML and Markdown. @@ -575,8 +594,11 @@ The transcript only includes text, it does not include markup like HTML and Mark {AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae. {USER_NAME}: Name a color. {AI_NAME}: Blue -{USER_NAME}:""" +{USER_NAME}: """ + params = gpt_params_parse() + if params.prompt is None and params.file is None: + params.prompt = prompt with LLaMAInteract(params) as m: m.interact() diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 33234fb..65206bf 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.28" \ No newline at end of file +__version__ = "0.2.29" \ No newline at end of file diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index 171f357..f7b6ba6 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -1,11 +1,15 @@ import os import sys +import sys, traceback + +# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor +outnull_file = open(os.devnull, "w") +errnull_file = open(os.devnull, "w") class suppress_stdout_stderr(object): # NOTE: these must be "saved" here to avoid exceptions when using # this context manager inside of a __del__ method - open = open sys = sys os = os @@ -21,9 +25,6 @@ class suppress_stdout_stderr(object): if not hasattr(self.sys.stdout, 'fileno') or not hasattr(self.sys.stderr, 'fileno'): return self # Return the instance without making changes - self.outnull_file = self.open(self.os.devnull, "w") - self.errnull_file = self.open(self.os.devnull, "w") - self.old_stdout_fileno_undup = self.sys.stdout.fileno() self.old_stderr_fileno_undup = self.sys.stderr.fileno() @@ -33,11 +34,11 @@ class suppress_stdout_stderr(object): self.old_stdout = self.sys.stdout self.old_stderr = self.sys.stderr - self.os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) - self.os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup) + self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup) - self.sys.stdout = self.outnull_file - self.sys.stderr = self.errnull_file + self.sys.stdout = outnull_file + self.sys.stderr = errnull_file return self def __exit__(self, *_): @@ -54,6 +55,3 @@ class suppress_stdout_stderr(object): self.os.close(self.old_stdout_fileno) self.os.close(self.old_stderr_fileno) - - self.outnull_file.close() - self.errnull_file.close() diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 7c819b0..e4be9d1 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -730,11 +730,13 @@ class Llama: *, # Model Params n_gpu_layers: int = 0, + split_mode: int = llama_cpp.LLAMA_SPLIT_LAYER, main_gpu: int = 0, tensor_split: Optional[List[float]] = None, vocab_only: bool = False, use_mmap: bool = True, use_mlock: bool = False, + kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None, # Context Params seed: int = llama_cpp.LLAMA_DEFAULT_SEED, n_ctx: int = 512, @@ -798,11 +800,13 @@ class Llama: Args: model_path: Path to the model. n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded. - main_gpu: The GPU that is used for scratch and small tensors. + split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options. + main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split. vocab_only: Only load the vocabulary no weights. use_mmap: Use mmap if possible. use_mlock: Force the system to keep the model in RAM. + kv_overrides: Key-value overrides for the model. seed: RNG seed, -1 for random n_ctx: Text context, 0 = from model n_batch: Prompt processing maximum batch size @@ -848,6 +852,7 @@ class Llama: self.model_params.n_gpu_layers = ( 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers ) # 0x7FFFFFFF is INT32 max, will be auto set to all layers + self.model_params.split_mode = split_mode self.model_params.main_gpu = main_gpu self.tensor_split = tensor_split self._c_tensor_split = None @@ -866,6 +871,34 @@ class Llama: self.model_params.use_mmap = use_mmap if lora_path is None else False self.model_params.use_mlock = use_mlock + self.kv_overrides = kv_overrides + if kv_overrides is not None: + n_overrides = len(kv_overrides) + self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1) + self._kv_overrides_array_keys = [] + + for k, v in kv_overrides.items(): + key_buf = ctypes.create_string_buffer(k.encode("utf-8")) + self._kv_overrides_array_keys.append(key_buf) + self._kv_overrides_array[i].key = key_buf + if isinstance(v, int): + self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT + self._kv_overrides_array[i].value.int_value = v + elif isinstance(v, float): + self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT + self._kv_overrides_array[i].value.float_value = v + elif isinstance(v, bool): + self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL + self._kv_overrides_array[i].value.bool_value = v + else: + raise ValueError(f"Unknown value type for {k}: {v}") + + self._kv_overrides_array_sentinel_key = b'\0' + + # null array sentinel + self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key + self.model_params.kv_overrides = self._kv_overrides_array + self.n_batch = min(n_ctx, n_batch) # ??? self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) self.n_threads_batch = n_threads_batch or max( @@ -2143,11 +2176,13 @@ class Llama: model_path=self.model_path, # Model Params n_gpu_layers=self.model_params.n_gpu_layers, + split_mode=self.model_params.split_mode, main_gpu=self.model_params.main_gpu, tensor_split=self.tensor_split, vocab_only=self.model_params.vocab_only, use_mmap=self.model_params.use_mmap, use_mlock=self.model_params.use_mlock, + kv_overrides=self.kv_overrides, # Context Params seed=self.context_params.seed, n_ctx=self.context_params.n_ctx, @@ -2185,11 +2220,13 @@ class Llama: model_path=state["model_path"], # Model Params n_gpu_layers=state["n_gpu_layers"], + split_mode=state["split_mode"], main_gpu=state["main_gpu"], tensor_split=state["tensor_split"], vocab_only=state["vocab_only"], use_mmap=state["use_mmap"], use_mlock=state["use_mlock"], + kv_overrides=state["kv_overrides"], # Context Params seed=state["seed"], n_ctx=state["n_ctx"], diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 989b67a..9e8e3ce 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -229,6 +229,7 @@ LLAMA_SPLIT_NONE = 0 LLAMA_SPLIT_LAYER = 1 LLAMA_SPLIT_ROW = 2 + # typedef struct llama_token_data { # llama_token id; // token id # float logit; // log-odds of the token @@ -395,6 +396,7 @@ class llama_model_kv_override(Structure): # // override key-value pairs of the model meta data # const struct llama_model_kv_override * kv_overrides; + # // Keep the booleans together to avoid misalignment during copy-by-value. # bool vocab_only; // only load the vocabulary, no weights # bool use_mmap; // use mmap if possible @@ -407,7 +409,7 @@ class llama_model_params(Structure): n_gpu_layers (int): number of layers to store in VRAM split_mode (int): how to split the model across multiple GPUs main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored - tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES + tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted. progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data @@ -526,6 +528,7 @@ It might not exist for progress report where '.' is output repeatedly.""" # bool quantize_output_tensor; // quantize output.weight # bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored # bool pure; // disable k-quant mixtures and quantize all tensors to the same type +# void * imatrix; // pointer to importance matrix data # } llama_model_quantize_params; class llama_model_quantize_params(Structure): """Parameters for llama_model_quantize @@ -537,6 +540,7 @@ class llama_model_quantize_params(Structure): quantize_output_tensor (bool): quantize output.weight only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored pure (bool): disable k-quant mixtures and quantize all tensors to the same type + imatrix (ctypes.c_void_p): pointer to importance matrix data """ _fields_ = [ @@ -545,6 +549,8 @@ class llama_model_quantize_params(Structure): ("allow_requantize", c_bool), ("quantize_output_tensor", c_bool), ("only_copy", c_bool), + ("pure", c_bool), + ("imatrix", c_void_p), ] @@ -1956,14 +1962,39 @@ _lib.llama_sample_repetition_penalties.restype = None # /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 -# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. -# /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. -# /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. -# LLAMA_API void llama_sample_classifier_free_guidance( -# struct llama_context * ctx, +# /// @param logits Logits extracted from the original generation context. +# /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. +# /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. +# LLAMA_API void llama_sample_apply_guidance( +# struct llama_context * ctx, +# float * logits, +# float * logits_guidance, +# float scale); +def llama_sample_apply_guidance( + ctx: llama_context_p, + logits, # type: _Pointer[c_float] + logits_guidance, # type: _Pointer[c_float] + scale: Union[c_float, float], +): + """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806""" + return _lib.llama_sample_apply_guidance(ctx, logits, logits_guidance, scale) + + +_lib.llama_sample_apply_guidance.argtypes = [ + llama_context_p, + c_float_p, + c_float_p, + c_float, +] +_lib.llama_sample_apply_guidance.restype = None + + +# LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance( +# struct llama_context * ctx, # llama_token_data_array * candidates, -# struct llama_context * guidance_ctx, -# float scale); +# struct llama_context * guidance_ctx, +# float scale), +# "use llama_sample_apply_guidance() instead"); def llama_sample_classifier_free_guidance( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 0c3b2e0..c02e656 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1433,7 +1433,6 @@ class SchemaConverter: def visit(self, schema: Dict[str, Any], name: str) -> str: schema_type: Optional[str] = schema.get("type") # type: ignore - assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" rule_name = name or "root" if "$defs" in schema: diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index c54e4eb..fed0a6d 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -197,7 +197,36 @@ async def authenticate( @router.post( - "/v1/completions", summary="Completion", dependencies=[Depends(authenticate)] + "/v1/completions", + summary="Completion", + dependencies=[Depends(authenticate)], + response_model= Union[ + llama_cpp.CreateCompletionResponse, + str, + ], + responses={ + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + {"$ref": "#/components/schemas/CreateCompletionResponse"} + ], + "title": "Completion response, when stream=False", + } + }, + "text/event-stream":{ + "schema": { + "type": "string", + "title": "Server Side Streaming response, when stream=True. " + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 + "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""" + } + } + }, + } + }, ) @router.post( "/v1/engines/copilot-codex/completions", @@ -280,7 +309,33 @@ async def create_embedding( @router.post( - "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)] + "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)], + response_model= Union[ + llama_cpp.ChatCompletion, str + ], + responses={ + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + {"$ref": "#/components/schemas/CreateChatCompletionResponse"} + ], + "title": "Completion response, when stream=False", + } + }, + "text/event-stream":{ + "schema": { + "type": "string", + "title": "Server Side Streaming response, when stream=True" + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 + "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""" + } + } + }, + } + }, ) async def create_chat_completion( request: Request, diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index b9373b7..f9be323 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Union, List +from typing import Dict, Optional, Union, List import llama_cpp @@ -71,6 +71,23 @@ class LlamaProxy: chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler( clip_model_path=settings.clip_model_path, verbose=settings.verbose ) + + kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None + if settings.kv_overrides is not None: + assert isinstance(settings.kv_overrides, list) + kv_overrides = {} + for kv in settings.kv_overrides: + key, value = kv.split("=") + if ":" in value: + value_type, value = value.split(":") + if value_type == "bool": + kv_overrides[key] = value.lower() in ["true", "1"] + elif value_type == "int": + kv_overrides[key] = int(value) + elif value_type == "float": + kv_overrides[key] = float(value) + else: + raise ValueError(f"Unknown value type {value_type}") _model = llama_cpp.Llama( model_path=settings.model, @@ -81,6 +98,7 @@ class LlamaProxy: vocab_only=settings.vocab_only, use_mmap=settings.use_mmap, use_mlock=settings.use_mlock, + kv_overrides=kv_overrides, # Context Params seed=settings.seed, n_ctx=settings.n_ctx, diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 346b463..a10390c 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -28,6 +28,10 @@ class ModelSettings(BaseSettings): ge=-1, description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.", ) + split_mode: int = Field( + default=llama_cpp.LLAMA_SPLIT_LAYER, + description="The split mode to use.", + ) main_gpu: int = Field( default=0, ge=0, @@ -48,11 +52,15 @@ class ModelSettings(BaseSettings): default=llama_cpp.llama_mlock_supported(), description="Use mlock.", ) + kv_overrides: Optional[List[str]] = Field( + default=None, + description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.", + ) # Context Params seed: int = Field( default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random." ) - n_ctx: int = Field(default=2048, ge=1, description="The context size.") + n_ctx: int = Field(default=2048, ge=0, description="The context size.") n_batch: int = Field( default=512, ge=1, description="The batch size to use per eval." ) diff --git a/tests/test_grammar.py b/tests/test_grammar.py index 2e24903..ef9392b 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -1,4 +1,5 @@ import llama_cpp +import json tree = """ leaf ::= "." @@ -6,8 +7,46 @@ node ::= leaf | "(" node node ")" root ::= node """ + def test_grammar_from_string(): grammar = llama_cpp.LlamaGrammar.from_string(tree) assert grammar._n_rules == 3 assert grammar._start_rule_index == 2 assert grammar.grammar is not None + + +def test_composed_pydantic_grammar(): + """ + from pydantic import BaseModel + + class A(BaseModel): + a: int + + class B(BaseModel): + a: A + b: int + """ + + # This schema corresponds to the grammar in the comment above. + # We don't use the pydantic models directly to avoid the dependency. + schema = { + "$defs": { + "A": { + "properties": {"a": {"title": "A", "type": "integer"}}, + "required": ["a"], + "title": "A", + "type": "object", + } + }, + "properties": { + "a": {"$ref": "#/$defs/A"}, + "b": {"title": "B", "type": "integer"}, + }, + "required": ["a", "b"], + "title": "B", + "type": "object", + } + + grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) + + assert grammar.grammar is not None diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 76484fb..5c99960 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 76484fbfd355df388f71d6edaa98e1692a74de7e +Subproject commit 5c999609013a30c06e6fd28be8db5c2074bcc196