From 12b7f2f4e9c68efa5555bacfaa49e44eb08efe60 Mon Sep 17 00:00:00 2001 From: Dave <69651599+D4ve-R@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:51:25 +0100 Subject: [PATCH] [Feat] Multi model support (#931) * Update Llama class to handle chat_format & caching * Add settings.py * Add util.py & update __main__.py * multimodel * update settings.py * cleanup * delete util.py * Fix /v1/models endpoint * MultiLlama now iterable, app check-alive on "/" * instant model init if file is given * backward compability * revert model param mandatory * fix error * handle individual model config json * refactor * revert chathandler/clip_model changes * handle chat_handler in MulitLlama() * split settings into server/llama * reduce global vars * Update LlamaProxy to handle config files * Add free method to LlamaProxy * update arg parsers & install server alias * refactor cache settings * change server executable name * better var name * whitespace * Revert "whitespace" This reverts commit bc5cf51c64a95bfc9926e1bc58166059711a1cd8. * remove exe_name * Fix merge bugs * Fix type annotations * Fix type annotations * Fix uvicorn app factory * Fix settings * Refactor server * Remove formatting fix * Format * Use default model if not found in model settings * Fix * Cleanup * Fix * Fix * Remove unnused CommandLineSettings * Cleanup * Support default name for copilot-codex models --------- Co-authored-by: Andrei Betlen --- llama_cpp/server/__main__.py | 123 +++-- llama_cpp/server/app.py | 854 ++++++----------------------------- llama_cpp/server/cli.py | 97 ++++ llama_cpp/server/errors.py | 210 +++++++++ llama_cpp/server/model.py | 126 ++++++ llama_cpp/server/settings.py | 161 +++++++ llama_cpp/server/types.py | 264 +++++++++++ 7 files changed, 1042 insertions(+), 793 deletions(-) create mode 100644 llama_cpp/server/cli.py create mode 100644 llama_cpp/server/errors.py create mode 100644 llama_cpp/server/model.py create mode 100644 llama_cpp/server/settings.py create mode 100644 llama_cpp/server/types.py diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index 45fc5a8..fadfc5f 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -9,7 +9,7 @@ export MODEL=../models/7B/... Then run: ``` -uvicorn llama_cpp.server.app:app --reload +uvicorn llama_cpp.server.app:create_app --reload ``` or @@ -21,81 +21,68 @@ python3 -m llama_cpp.server Then visit http://localhost:8000/docs to see the interactive API docs. 
""" +from __future__ import annotations + import os +import sys import argparse -from typing import List, Literal, Union import uvicorn -from llama_cpp.server.app import create_app, Settings +from llama_cpp.server.app import create_app +from llama_cpp.server.settings import ( + Settings, + ServerSettings, + ModelSettings, + ConfigFileSettings, +) +from llama_cpp.server.cli import add_args_from_model, parse_model_from_args -def get_base_type(annotation): - if getattr(annotation, '__origin__', None) is Literal: - return type(annotation.__args__[0]) - elif getattr(annotation, '__origin__', None) is Union: - non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)] - if non_optional_args: - return get_base_type(non_optional_args[0]) - elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List: - return get_base_type(annotation.__args__[0]) - else: - return annotation -def contains_list_type(annotation) -> bool: - origin = getattr(annotation, '__origin__', None) - - if origin is list or origin is List: - return True - elif origin in (Literal, Union): - return any(contains_list_type(arg) for arg in annotation.__args__) - else: - return False +def main(): + description = "🦙 Llama.cpp python server. Host your own LLMs!🚀" + parser = argparse.ArgumentParser(description=description) -def parse_bool_arg(arg): - if isinstance(arg, bytes): - arg = arg.decode('utf-8') + add_args_from_model(parser, Settings) + parser.add_argument( + "--config_file", + type=str, + help="Path to a config file to load.", + ) + server_settings: ServerSettings | None = None + model_settings: list[ModelSettings] = [] + args = parser.parse_args() + try: + # Load server settings from config_file if provided + config_file = os.environ.get("CONFIG_FILE", args.config_file) + if config_file: + if not os.path.exists(config_file): + raise ValueError(f"Config file {config_file} not found!") + with open(config_file, "rb") as f: + config_file_settings = ConfigFileSettings.model_validate_json(f.read()) + server_settings = ServerSettings.model_validate(config_file_settings) + model_settings = config_file_settings.models + else: + server_settings = parse_model_from_args(ServerSettings, args) + model_settings = [parse_model_from_args(ModelSettings, args)] + except Exception as e: + print(e, file=sys.stderr) + parser.print_help() + sys.exit(1) + assert server_settings is not None + assert model_settings is not None + app = create_app( + server_settings=server_settings, + model_settings=model_settings, + ) + uvicorn.run( + app, + host=os.getenv("HOST", server_settings.host), + port=int(os.getenv("PORT", server_settings.port)), + ssl_keyfile=server_settings.ssl_keyfile, + ssl_certfile=server_settings.ssl_certfile, + ) - true_values = {'1', 'on', 't', 'true', 'y', 'yes'} - false_values = {'0', 'off', 'f', 'false', 'n', 'no'} - - arg_str = str(arg).lower().strip() - - if arg_str in true_values: - return True - elif arg_str in false_values: - return False - else: - raise ValueError(f'Invalid boolean argument: {arg}') if __name__ == "__main__": - parser = argparse.ArgumentParser() - for name, field in Settings.model_fields.items(): - description = field.description - if field.default is not None and description is not None: - description += f" (default: {field.default})" - base_type = get_base_type(field.annotation) if field.annotation is not None else str - list_type = contains_list_type(field.annotation) - if base_type is not bool: - parser.add_argument( - f"--{name}", - 
dest=name, - nargs="*" if list_type else None, - type=base_type, - help=description, - ) - if base_type is bool: - parser.add_argument( - f"--{name}", - dest=name, - type=parse_bool_arg, - help=f"{description}", - ) - - args = parser.parse_args() - settings = Settings(**{k: v for k, v in vars(args).items() if v is not None}) - app = create_app(settings=settings) - - uvicorn.run( - app, host=os.getenv("HOST", settings.host), port=int(os.getenv("PORT", settings.port)), - ssl_keyfile=settings.ssl_keyfile, ssl_certfile=settings.ssl_certfile - ) + main() diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index db9705f..c54e4eb 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -1,375 +1,120 @@ -import sys +from __future__ import annotations + +import os import json -import traceback -import multiprocessing -import time -from re import compile, Match, Pattern + from threading import Lock from functools import partial -from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict -from typing_extensions import TypedDict, Literal +from typing import Iterator, List, Optional, Union, Dict import llama_cpp import anyio from anyio.streams.memory import MemoryObjectSendStream from starlette.concurrency import run_in_threadpool, iterate_in_threadpool -from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status +from fastapi import ( + Depends, + FastAPI, + APIRouter, + Request, + HTTPException, + status, +) from fastapi.middleware import Middleware from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from fastapi.routing import APIRoute from fastapi.security import HTTPBearer -from pydantic import BaseModel, Field -from pydantic_settings import BaseSettings from sse_starlette.sse import EventSourceResponse -from starlette_context import plugins +from starlette_context.plugins import RequestIdPlugin # type: ignore from starlette_context.middleware import RawContextMiddleware -import numpy as np -import numpy.typing as npt - - -# Disable warning for model and model_alias settings -BaseSettings.model_config["protected_namespaces"] = () - - -class Settings(BaseSettings): - model: str = Field( - description="The path to the model to use for generating completions." - ) - model_alias: Optional[str] = Field( - default=None, - description="The alias of the model to use for generating completions.", - ) - # Model Params - n_gpu_layers: int = Field( - default=0, - ge=-1, - description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.", - ) - main_gpu: int = Field( - default=0, - ge=0, - description="Main GPU to use.", - ) - tensor_split: Optional[List[float]] = Field( - default=None, - description="Split layers across multiple GPUs in proportion.", - ) - vocab_only: bool = Field( - default=False, description="Whether to only return the vocabulary." - ) - use_mmap: bool = Field( - default=llama_cpp.llama_mmap_supported(), - description="Use mmap.", - ) - use_mlock: bool = Field( - default=llama_cpp.llama_mlock_supported(), - description="Use mlock.", - ) - # Context Params - seed: int = Field( - default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random." - ) - n_ctx: int = Field(default=2048, ge=1, description="The context size.") - n_batch: int = Field( - default=512, ge=1, description="The batch size to use per eval." 
- ) - n_threads: int = Field( - default=max(multiprocessing.cpu_count() // 2, 1), - ge=1, - description="The number of threads to use.", - ) - n_threads_batch: int = Field( - default=max(multiprocessing.cpu_count() // 2, 1), - ge=0, - description="The number of threads to use when batch processing.", - ) - rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED) - rope_freq_base: float = Field(default=0.0, description="RoPE base frequency") - rope_freq_scale: float = Field( - default=0.0, description="RoPE frequency scaling factor" - ) - yarn_ext_factor: float = Field(default=-1.0) - yarn_attn_factor: float = Field(default=1.0) - yarn_beta_fast: float = Field(default=32.0) - yarn_beta_slow: float = Field(default=1.0) - yarn_orig_ctx: int = Field(default=0) - mul_mat_q: bool = Field( - default=True, description="if true, use experimental mul_mat_q kernels" - ) - logits_all: bool = Field(default=True, description="Whether to return logits.") - embedding: bool = Field(default=True, description="Whether to use embeddings.") - offload_kqv: bool = Field( - default=False, description="Whether to offload kqv to the GPU." - ) - # Sampling Params - last_n_tokens_size: int = Field( - default=64, - ge=0, - description="Last n tokens to keep for repeat penalty calculation.", - ) - # LoRA Params - lora_base: Optional[str] = Field( - default=None, - description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.", - ) - lora_path: Optional[str] = Field( - default=None, - description="Path to a LoRA file to apply to the model.", - ) - # Backend Params - numa: bool = Field( - default=False, - description="Enable NUMA support.", - ) - # Chat Format Params - chat_format: str = Field( - default="llama-2", - description="Chat format to use.", - ) - clip_model_path: Optional[str] = Field( - default=None, - description="Path to a CLIP model to use for multi-modal chat completion.", - ) - # Cache Params - cache: bool = Field( - default=False, - description="Use a cache to reduce processing times for evaluated prompts.", - ) - cache_type: Literal["ram", "disk"] = Field( - default="ram", - description="The type of cache to use. Only used if cache is True.", - ) - cache_size: int = Field( - default=2 << 30, - description="The size of the cache in bytes. Only used if cache is True.", - ) - # Misc - verbose: bool = Field( - default=True, description="Whether to print debug information." - ) - # Server Params - host: str = Field(default="localhost", description="Listen address") - port: int = Field(default=8000, description="Listen port") - # SSL Params - ssl_keyfile: Optional[str] = Field( - default=None, description="SSL key file for HTTPS" - ) - ssl_certfile: Optional[str] = Field( - default=None, description="SSL certificate file for HTTPS" - ) - interrupt_requests: bool = Field( - default=True, - description="Whether to interrupt requests when a new request is received.", - ) - api_key: Optional[str] = Field( - default=None, - description="API key for authentication. If set all requests need to be authenticated." - ) - - -class ErrorResponse(TypedDict): - """OpenAI style error response""" - - message: str - type: str - param: Optional[str] - code: Optional[str] - - -class ErrorResponseFormatters: - """Collection of formatters for error responses. 
- - Args: - request (Union[CreateCompletionRequest, CreateChatCompletionRequest]): - Request body - match (Match[str]): Match object from regex pattern - - Returns: - Tuple[int, ErrorResponse]: Status code and error response - """ - - @staticmethod - def context_length_exceeded( - request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], - match, # type: Match[str] # type: ignore - ) -> Tuple[int, ErrorResponse]: - """Formatter for context length exceeded error""" - - context_window = int(match.group(2)) - prompt_tokens = int(match.group(1)) - completion_tokens = request.max_tokens - if hasattr(request, "messages"): - # Chat completion - message = ( - "This model's maximum context length is {} tokens. " - "However, you requested {} tokens " - "({} in the messages, {} in the completion). " - "Please reduce the length of the messages or completion." - ) - else: - # Text completion - message = ( - "This model's maximum context length is {} tokens, " - "however you requested {} tokens " - "({} in your prompt; {} for the completion). " - "Please reduce your prompt; or completion length." - ) - return 400, ErrorResponse( - message=message.format( - context_window, - completion_tokens + prompt_tokens, - prompt_tokens, - completion_tokens, - ), - type="invalid_request_error", - param="messages", - code="context_length_exceeded", - ) - - @staticmethod - def model_not_found( - request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], - match, # type: Match[str] # type: ignore - ) -> Tuple[int, ErrorResponse]: - """Formatter for model_not_found error""" - - model_path = str(match.group(1)) - message = f"The model `{model_path}` does not exist" - return 400, ErrorResponse( - message=message, - type="invalid_request_error", - param=None, - code="model_not_found", - ) - - -class RouteErrorHandler(APIRoute): - """Custom APIRoute that handles application errors and exceptions""" - - # key: regex pattern for original error message from llama_cpp - # value: formatter function - pattern_and_formatters: Dict[ - "Pattern", - Callable[ - [ - Union["CreateCompletionRequest", "CreateChatCompletionRequest"], - "Match[str]", - ], - Tuple[int, ErrorResponse], - ], - ] = { - compile( - r"Requested tokens \((\d+)\) exceed context window of (\d+)" - ): ErrorResponseFormatters.context_length_exceeded, - compile( - r"Model path does not exist: (.+)" - ): ErrorResponseFormatters.model_not_found, - } - - def error_message_wrapper( - self, - error: Exception, - body: Optional[ - Union[ - "CreateChatCompletionRequest", - "CreateCompletionRequest", - "CreateEmbeddingRequest", - ] - ] = None, - ) -> Tuple[int, ErrorResponse]: - """Wraps error message in OpenAI style error response""" - print(f"Exception: {str(error)}", file=sys.stderr) - traceback.print_exc(file=sys.stderr) - if body is not None and isinstance( - body, - ( - CreateCompletionRequest, - CreateChatCompletionRequest, - ), - ): - # When text completion or chat completion - for pattern, callback in self.pattern_and_formatters.items(): - match = pattern.search(str(error)) - if match is not None: - return callback(body, match) - - # Wrap other errors as internal server error - return 500, ErrorResponse( - message=str(error), - type="internal_server_error", - param=None, - code=None, - ) - - def get_route_handler( - self, - ) -> Callable[[Request], Coroutine[None, None, Response]]: - """Defines custom route handler that catches exceptions and formats - in OpenAI style error response""" - - original_route_handler = 
super().get_route_handler() - - async def custom_route_handler(request: Request) -> Response: - try: - start_sec = time.perf_counter() - response = await original_route_handler(request) - elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000) - response.headers["openai-processing-ms"] = f"{elapsed_time_ms}" - return response - except HTTPException as unauthorized: - # api key check failed - raise unauthorized - except Exception as exc: - json_body = await request.json() - try: - if "messages" in json_body: - # Chat completion - body: Optional[ - Union[ - CreateChatCompletionRequest, - CreateCompletionRequest, - CreateEmbeddingRequest, - ] - ] = CreateChatCompletionRequest(**json_body) - elif "prompt" in json_body: - # Text completion - body = CreateCompletionRequest(**json_body) - else: - # Embedding - body = CreateEmbeddingRequest(**json_body) - except Exception: - # Invalid request body - body = None - - # Get proper error message from the exception - ( - status_code, - error_message, - ) = self.error_message_wrapper(error=exc, body=body) - return JSONResponse( - {"error": error_message}, - status_code=status_code, - ) - - return custom_route_handler +from llama_cpp.server.model import ( + LlamaProxy, +) +from llama_cpp.server.settings import ( + ConfigFileSettings, + Settings, + ModelSettings, + ServerSettings, +) +from llama_cpp.server.types import ( + CreateCompletionRequest, + CreateEmbeddingRequest, + CreateChatCompletionRequest, + ModelList, +) +from llama_cpp.server.errors import RouteErrorHandler router = APIRouter(route_class=RouteErrorHandler) -settings: Optional[Settings] = None -llama: Optional[llama_cpp.Llama] = None +_server_settings: Optional[ServerSettings] = None -def create_app(settings: Optional[Settings] = None): - if settings is None: - settings = Settings() +def set_server_settings(server_settings: ServerSettings): + global _server_settings + _server_settings = server_settings - middleware = [ - Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),)) - ] + +def get_server_settings(): + yield _server_settings + + +_llama_proxy: Optional[LlamaProxy] = None + +llama_outer_lock = Lock() +llama_inner_lock = Lock() + + +def set_llama_proxy(model_settings: List[ModelSettings]): + global _llama_proxy + _llama_proxy = LlamaProxy(models=model_settings) + + +def get_llama_proxy(): + # NOTE: This double lock allows the currently streaming llama model to + # check if any other requests are pending in the same thread and cancel + # the stream if so. 
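As an aside, the double-lock hand-off referenced in the NOTE above can be shown with a small self-contained sketch (illustrative only; outer, inner, and borrow are not names from this patch). The streaming request ends up holding only the inner lock, while a newly arriving request sits on the outer lock, so the streamer can poll outer.locked() to detect a pending request and cut its stream short:

from contextlib import contextmanager
from threading import Lock

outer, inner = Lock(), Lock()

@contextmanager
def borrow(resource):
    outer.acquire()
    release_outer = True
    try:
        inner.acquire()
        try:
            outer.release()   # hand the outer lock back so the next caller can queue up
            release_outer = False
            yield resource    # caller streams while holding only the inner lock
        finally:
            inner.release()
    finally:
        if release_outer:
            outer.release()

# Inside `with borrow(...)`, outer.locked() is True only when another caller has
# entered borrow() and is waiting on the inner lock.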
+ llama_outer_lock.acquire() + release_outer_lock = True + try: + llama_inner_lock.acquire() + try: + llama_outer_lock.release() + release_outer_lock = False + yield _llama_proxy + finally: + llama_inner_lock.release() + finally: + if release_outer_lock: + llama_outer_lock.release() + + +def create_app( + settings: Settings | None = None, + server_settings: ServerSettings | None = None, + model_settings: List[ModelSettings] | None = None, +): + config_file = os.environ.get("CONFIG_FILE", None) + if config_file is not None: + if not os.path.exists(config_file): + raise ValueError(f"Config file {config_file} not found!") + with open(config_file, "rb") as f: + config_file_settings = ConfigFileSettings.model_validate_json(f.read()) + server_settings = ServerSettings.model_validate(config_file_settings) + model_settings = config_file_settings.models + + if server_settings is None and model_settings is None: + if settings is None: + settings = Settings() + server_settings = ServerSettings.model_validate(settings) + model_settings = [ModelSettings.model_validate(settings)] + + assert ( + server_settings is not None and model_settings is not None + ), "server_settings and model_settings must be provided together" + + set_server_settings(server_settings) + middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))] app = FastAPI( middleware=middleware, title="🦙 llama.cpp Python API", @@ -383,105 +128,13 @@ def create_app(settings: Optional[Settings] = None): allow_headers=["*"], ) app.include_router(router) - global llama - ## - chat_handler = None - if settings.chat_format == "llava-1-5": - assert settings.clip_model_path is not None - chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler( - clip_model_path=settings.clip_model_path, verbose=settings.verbose - ) - ## + assert model_settings is not None + set_llama_proxy(model_settings=model_settings) - llama = llama_cpp.Llama( - model_path=settings.model, - # Model Params - n_gpu_layers=settings.n_gpu_layers, - main_gpu=settings.main_gpu, - tensor_split=settings.tensor_split, - vocab_only=settings.vocab_only, - use_mmap=settings.use_mmap, - use_mlock=settings.use_mlock, - # Context Params - seed=settings.seed, - n_ctx=settings.n_ctx, - n_batch=settings.n_batch, - n_threads=settings.n_threads, - n_threads_batch=settings.n_threads_batch, - rope_scaling_type=settings.rope_scaling_type, - rope_freq_base=settings.rope_freq_base, - rope_freq_scale=settings.rope_freq_scale, - yarn_ext_factor=settings.yarn_ext_factor, - yarn_attn_factor=settings.yarn_attn_factor, - yarn_beta_fast=settings.yarn_beta_fast, - yarn_beta_slow=settings.yarn_beta_slow, - yarn_orig_ctx=settings.yarn_orig_ctx, - mul_mat_q=settings.mul_mat_q, - logits_all=settings.logits_all, - embedding=settings.embedding, - offload_kqv=settings.offload_kqv, - # Sampling Params - last_n_tokens_size=settings.last_n_tokens_size, - # LoRA Params - lora_base=settings.lora_base, - lora_path=settings.lora_path, - # Backend Params - numa=settings.numa, - # Chat Format Params - chat_format=settings.chat_format, - chat_handler=chat_handler, - # Misc - verbose=settings.verbose, - ) - if settings.cache: - if settings.cache_type == "disk": - if settings.verbose: - print(f"Using disk cache with size {settings.cache_size}") - cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size) - else: - if settings.verbose: - print(f"Using ram cache with size {settings.cache_size}") - cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size) - - cache = 
llama_cpp.LlamaCache(capacity_bytes=settings.cache_size) - llama.set_cache(cache) - - def set_settings(_settings: Settings): - global settings - settings = _settings - - set_settings(settings) return app -llama_outer_lock = Lock() -llama_inner_lock = Lock() - - -def get_llama(): - # NOTE: This double lock allows the currently streaming llama model to - # check if any other requests are pending in the same thread and cancel - # the stream if so. - llama_outer_lock.acquire() - release_outer_lock = True - try: - llama_inner_lock.acquire() - try: - llama_outer_lock.release() - release_outer_lock = False - yield llama - finally: - llama_inner_lock.release() - finally: - if release_outer_lock: - llama_outer_lock.release() - - -def get_settings(): - yield settings - - async def get_event_publisher( request: Request, inner_send_chan: MemoryObjectSendStream, @@ -493,7 +146,10 @@ async def get_event_publisher( await inner_send_chan.send(dict(data=json.dumps(chunk))) if await request.is_disconnected(): raise anyio.get_cancelled_exc_class()() - if settings.interrupt_requests and llama_outer_lock.locked(): + if ( + next(get_server_settings()).interrupt_requests + and llama_outer_lock.locked() + ): await inner_send_chan.send(dict(data="[DONE]")) raise anyio.get_cancelled_exc_class()() await inner_send_chan.send(dict(data="[DONE]")) @@ -504,156 +160,6 @@ async def get_event_publisher( raise e -model_field = Field( - description="The model to use for generating completions.", default=None -) - -max_tokens_field = Field( - default=16, ge=1, description="The maximum number of tokens to generate." -) - -temperature_field = Field( - default=0.8, - ge=0.0, - le=2.0, - description="Adjust the randomness of the generated text.\n\n" - + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.", -) - -top_p_field = Field( - default=0.95, - ge=0.0, - le=1.0, - description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n" - + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.", -) - -min_p_field = Field( - default=0.05, - ge=0.0, - le=1.0, - description="Sets a minimum base probability threshold for token selection.\n\n" - + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. 
For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.", -) - -stop_field = Field( - default=None, - description="A list of tokens at which to stop generation. If None, no stop tokens are used.", -) - -stream_field = Field( - default=False, - description="Whether to stream the results as they are generated. Useful for chatbots.", -) - -top_k_field = Field( - default=40, - ge=0, - description="Limit the next token selection to the K most probable tokens.\n\n" - + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.", -) - -repeat_penalty_field = Field( - default=1.1, - ge=0.0, - description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n" - + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.", -) - -presence_penalty_field = Field( - default=0.0, - ge=-2.0, - le=2.0, - description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", -) - -frequency_penalty_field = Field( - default=0.0, - ge=-2.0, - le=2.0, - description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", -) - -mirostat_mode_field = Field( - default=0, - ge=0, - le=2, - description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)", -) - -mirostat_tau_field = Field( - default=5.0, - ge=0.0, - le=10.0, - description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text", -) - -mirostat_eta_field = Field( - default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate" -) - -grammar = Field( - default=None, - description="A CBNF grammar (as string) to be used for formatting the model's output.", -) - - -class CreateCompletionRequest(BaseModel): - prompt: Union[str, List[str]] = Field( - default="", description="The prompt to generate completions for." - ) - suffix: Optional[str] = Field( - default=None, - description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.", - ) - max_tokens: int = max_tokens_field - temperature: float = temperature_field - top_p: float = top_p_field - min_p: float = min_p_field - echo: bool = Field( - default=False, - description="Whether to echo the prompt in the generated text. Useful for chatbots.", - ) - stop: Optional[Union[str, List[str]]] = stop_field - stream: bool = stream_field - logprobs: Optional[int] = Field( - default=None, - ge=0, - description="The number of logprobs to generate. 
If None, no logprobs are generated.", - ) - presence_penalty: Optional[float] = presence_penalty_field - frequency_penalty: Optional[float] = frequency_penalty_field - logit_bias: Optional[Dict[str, float]] = Field(None) - logprobs: Optional[int] = Field(None) - seed: Optional[int] = Field(None) - - # ignored or currently unsupported - model: Optional[str] = model_field - n: Optional[int] = 1 - best_of: Optional[int] = 1 - user: Optional[str] = Field(default=None) - - # llama.cpp specific parameters - top_k: int = top_k_field - repeat_penalty: float = repeat_penalty_field - logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) - mirostat_mode: int = mirostat_mode_field - mirostat_tau: float = mirostat_tau_field - mirostat_eta: float = mirostat_eta_field - grammar: Optional[str] = None - - model_config = { - "json_schema_extra": { - "examples": [ - { - "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n", - "stop": ["\n", "###"], - } - ] - } - } - - def _logit_bias_tokens_to_input_ids( llama: llama_cpp.Llama, logit_bias: Dict[str, float], @@ -670,7 +176,10 @@ def _logit_bias_tokens_to_input_ids( bearer_scheme = HTTPBearer(auto_error=False) -async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)): +async def authenticate( + settings: Settings = Depends(get_server_settings), + authorization: Optional[str] = Depends(bearer_scheme), +): # Skip API key check if it's not set in settings if settings.api_key is None: return True @@ -688,20 +197,28 @@ async def authenticate(settings: Settings = Depends(get_settings), authorization @router.post( - "/v1/completions", - summary="Completion" + "/v1/completions", summary="Completion", dependencies=[Depends(authenticate)] +) +@router.post( + "/v1/engines/copilot-codex/completions", + include_in_schema=False, + dependencies=[Depends(authenticate)], ) -@router.post("/v1/engines/copilot-codex/completions", include_in_schema=False) async def create_completion( request: Request, body: CreateCompletionRequest, - llama: llama_cpp.Llama = Depends(get_llama), - authenticated: str = Depends(authenticate), + llama_proxy: LlamaProxy = Depends(get_llama_proxy), ) -> llama_cpp.Completion: if isinstance(body.prompt, list): assert len(body.prompt) <= 1 body.prompt = body.prompt[0] if len(body.prompt) > 0 else "" + llama = llama_proxy( + body.model + if request.url.path != "/v1/engines/copilot-codex/completions" + else "copilot-codex" + ) + exclude = { "n", "best_of", @@ -749,124 +266,26 @@ async def create_completion( return iterator_or_completion -class CreateEmbeddingRequest(BaseModel): - model: Optional[str] = model_field - input: Union[str, List[str]] = Field(description="The input to embed.") - user: Optional[str] = Field(default=None) - - model_config = { - "json_schema_extra": { - "examples": [ - { - "input": "The food was delicious and the waiter...", - } - ] - } - } - - @router.post( - "/v1/embeddings", - summary="Embedding" + "/v1/embeddings", summary="Embedding", dependencies=[Depends(authenticate)] ) async def create_embedding( request: CreateEmbeddingRequest, - llama: llama_cpp.Llama = Depends(get_llama), - authenticated: str = Depends(authenticate), + llama_proxy: LlamaProxy = Depends(get_llama_proxy), ): return await run_in_threadpool( - llama.create_embedding, **request.model_dump(exclude={"user"}) + llama_proxy(request.model).create_embedding, + **request.model_dump(exclude={"user"}), ) -class ChatCompletionRequestMessage(BaseModel): - 
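A hypothetical client call against the endpoints above, showing how the "model" field now selects one of the configured aliases per request ("llama-2" is a placeholder alias; add an "Authorization: Bearer <api_key>" header if api_key is set). Standard library only:

import json
import urllib.request

payload = {
    "model": "llama-2",
    "prompt": "Q: Name the planets in the solar system. A: ",
    "max_tokens": 64,
    "stop": ["Q:", "\n"],
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])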
role: Literal["system", "user", "assistant", "function"] = Field( - default="user", description="The role of the message." - ) - content: Optional[str] = Field( - default="", description="The content of the message." - ) - - -class CreateChatCompletionRequest(BaseModel): - messages: List[llama_cpp.ChatCompletionRequestMessage] = Field( - default=[], description="A list of messages to generate completions for." - ) - functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field( - default=None, - description="A list of functions to apply to the generated completions.", - ) - function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field( - default=None, - description="A function to apply to the generated completions.", - ) - tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field( - default=None, - description="A list of tools to apply to the generated completions.", - ) - tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field( - default=None, - description="A tool to apply to the generated completions.", - ) # TODO: verify - max_tokens: Optional[int] = Field( - default=None, - description="The maximum number of tokens to generate. Defaults to inf", - ) - temperature: float = temperature_field - top_p: float = top_p_field - min_p: float = min_p_field - stop: Optional[Union[str, List[str]]] = stop_field - stream: bool = stream_field - presence_penalty: Optional[float] = presence_penalty_field - frequency_penalty: Optional[float] = frequency_penalty_field - logit_bias: Optional[Dict[str, float]] = Field(None) - seed: Optional[int] = Field(None) - response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field( - default=None, - ) - - # ignored or currently unsupported - model: Optional[str] = model_field - n: Optional[int] = 1 - user: Optional[str] = Field(None) - - # llama.cpp specific parameters - top_k: int = top_k_field - repeat_penalty: float = repeat_penalty_field - logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None) - mirostat_mode: int = mirostat_mode_field - mirostat_tau: float = mirostat_tau_field - mirostat_eta: float = mirostat_eta_field - grammar: Optional[str] = None - - model_config = { - "json_schema_extra": { - "examples": [ - { - "messages": [ - ChatCompletionRequestMessage( - role="system", content="You are a helpful assistant." - ).model_dump(), - ChatCompletionRequestMessage( - role="user", content="What is the capital of France?" 
- ).model_dump(), - ] - } - ] - } - } - - @router.post( - "/v1/chat/completions", - summary="Chat" + "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)] ) async def create_chat_completion( request: Request, body: CreateChatCompletionRequest, - llama: llama_cpp.Llama = Depends(get_llama), - settings: Settings = Depends(get_settings), - authenticated: str = Depends(authenticate), + llama_proxy: LlamaProxy = Depends(get_llama_proxy), ) -> llama_cpp.ChatCompletion: exclude = { "n", @@ -874,7 +293,7 @@ async def create_chat_completion( "user", } kwargs = body.model_dump(exclude=exclude) - + llama = llama_proxy(body.model) if body.logit_bias is not None: kwargs["logit_bias"] = ( _logit_bias_tokens_to_input_ids(llama, body.logit_bias) @@ -913,34 +332,19 @@ async def create_chat_completion( return iterator_or_completion -class ModelData(TypedDict): - id: str - object: Literal["model"] - owned_by: str - permissions: List[str] - - -class ModelList(TypedDict): - object: Literal["list"] - data: List[ModelData] - - -@router.get("/v1/models", summary="Models") +@router.get("/v1/models", summary="Models", dependencies=[Depends(authenticate)]) async def get_models( - settings: Settings = Depends(get_settings), - authenticated: str = Depends(authenticate), + llama_proxy: LlamaProxy = Depends(get_llama_proxy), ) -> ModelList: - assert llama is not None return { "object": "list", "data": [ { - "id": settings.model_alias - if settings.model_alias is not None - else llama.model_path, + "id": model_alias, "object": "model", "owned_by": "me", "permissions": [], } + for model_alias in llama_proxy ], } diff --git a/llama_cpp/server/cli.py b/llama_cpp/server/cli.py new file mode 100644 index 0000000..8e32d2c --- /dev/null +++ b/llama_cpp/server/cli.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import argparse + +from typing import List, Literal, Union, Any, Type, TypeVar + +from pydantic import BaseModel + + +def _get_base_type(annotation: Type[Any]) -> Type[Any]: + if getattr(annotation, "__origin__", None) is Literal: + assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore + return type(annotation.__args__[0]) # type: ignore + elif getattr(annotation, "__origin__", None) is Union: + assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore + non_optional_args: List[Type[Any]] = [ + arg for arg in annotation.__args__ if arg is not type(None) # type: ignore + ] + if non_optional_args: + return _get_base_type(non_optional_args[0]) + elif ( + getattr(annotation, "__origin__", None) is list + or getattr(annotation, "__origin__", None) is List + ): + assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore + return _get_base_type(annotation.__args__[0]) # type: ignore + return annotation + + +def _contains_list_type(annotation: Type[Any] | None) -> bool: + origin = getattr(annotation, "__origin__", None) + + if origin is list or origin is List: + return True + elif origin in (Literal, Union): + return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore + else: + return False + + +def _parse_bool_arg(arg: str | bytes | bool) -> bool: + if isinstance(arg, bytes): + arg = arg.decode("utf-8") + + true_values = {"1", "on", "t", "true", "y", "yes"} + false_values = {"0", "off", "f", "false", "n", "no"} + + arg_str = str(arg).lower().strip() + + if arg_str in true_values: + return True + elif arg_str in false_values: + return False + else: + raise ValueError(f"Invalid 
boolean argument: {arg}") + + +def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]): + """Add arguments from a pydantic model to an argparse parser.""" + + for name, field in model.model_fields.items(): + description = field.description + if field.default and description and not field.is_required(): + description += f" (default: {field.default})" + base_type = ( + _get_base_type(field.annotation) if field.annotation is not None else str + ) + list_type = _contains_list_type(field.annotation) + if base_type is not bool: + parser.add_argument( + f"--{name}", + dest=name, + nargs="*" if list_type else None, + type=base_type, + help=description, + ) + if base_type is bool: + parser.add_argument( + f"--{name}", + dest=name, + type=_parse_bool_arg, + help=f"{description}", + ) + + +T = TypeVar("T", bound=type[BaseModel]) + + +def parse_model_from_args(model: T, args: argparse.Namespace) -> T: + """Parse a pydantic model from an argparse namespace.""" + return model( + **{ + k: v + for k, v in vars(args).items() + if v is not None and k in model.model_fields + } + ) diff --git a/llama_cpp/server/errors.py b/llama_cpp/server/errors.py new file mode 100644 index 0000000..febe3e3 --- /dev/null +++ b/llama_cpp/server/errors.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import sys +import traceback +import time +from re import compile, Match, Pattern +from typing import Callable, Coroutine, Optional, Tuple, Union, Dict +from typing_extensions import TypedDict + + +from fastapi import ( + Request, + Response, + HTTPException, +) +from fastapi.responses import JSONResponse +from fastapi.routing import APIRoute + +from llama_cpp.server.types import ( + CreateCompletionRequest, + CreateEmbeddingRequest, + CreateChatCompletionRequest, +) + +class ErrorResponse(TypedDict): + """OpenAI style error response""" + + message: str + type: str + param: Optional[str] + code: Optional[str] + + +class ErrorResponseFormatters: + """Collection of formatters for error responses. + + Args: + request (Union[CreateCompletionRequest, CreateChatCompletionRequest]): + Request body + match (Match[str]): Match object from regex pattern + + Returns: + Tuple[int, ErrorResponse]: Status code and error response + """ + + @staticmethod + def context_length_exceeded( + request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + match, # type: Match[str] # type: ignore + ) -> Tuple[int, ErrorResponse]: + """Formatter for context length exceeded error""" + + context_window = int(match.group(2)) + prompt_tokens = int(match.group(1)) + completion_tokens = request.max_tokens + if hasattr(request, "messages"): + # Chat completion + message = ( + "This model's maximum context length is {} tokens. " + "However, you requested {} tokens " + "({} in the messages, {} in the completion). " + "Please reduce the length of the messages or completion." + ) + else: + # Text completion + message = ( + "This model's maximum context length is {} tokens, " + "however you requested {} tokens " + "({} in your prompt; {} for the completion). " + "Please reduce your prompt; or completion length." 
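A quick usage sketch of the CLI helpers introduced in cli.py above (the argument values are arbitrary examples):

import argparse

from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
from llama_cpp.server.settings import ServerSettings

parser = argparse.ArgumentParser()
add_args_from_model(parser, ServerSettings)
args = parser.parse_args(["--port", "8080", "--api_key", "secret"])
server_settings = parse_model_from_args(ServerSettings, args)
print(server_settings.port, server_settings.api_key)  # 8080 secret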
+ ) + return 400, ErrorResponse( + message=message.format( + context_window, + completion_tokens + prompt_tokens, + prompt_tokens, + completion_tokens, + ), # type: ignore + type="invalid_request_error", + param="messages", + code="context_length_exceeded", + ) + + @staticmethod + def model_not_found( + request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + match, # type: Match[str] # type: ignore + ) -> Tuple[int, ErrorResponse]: + """Formatter for model_not_found error""" + + model_path = str(match.group(1)) + message = f"The model `{model_path}` does not exist" + return 400, ErrorResponse( + message=message, + type="invalid_request_error", + param=None, + code="model_not_found", + ) + + +class RouteErrorHandler(APIRoute): + """Custom APIRoute that handles application errors and exceptions""" + + # key: regex pattern for original error message from llama_cpp + # value: formatter function + pattern_and_formatters: Dict[ + "Pattern[str]", + Callable[ + [ + Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + "Match[str]", + ], + Tuple[int, ErrorResponse], + ], + ] = { + compile( + r"Requested tokens \((\d+)\) exceed context window of (\d+)" + ): ErrorResponseFormatters.context_length_exceeded, + compile( + r"Model path does not exist: (.+)" + ): ErrorResponseFormatters.model_not_found, + } + + def error_message_wrapper( + self, + error: Exception, + body: Optional[ + Union[ + "CreateChatCompletionRequest", + "CreateCompletionRequest", + "CreateEmbeddingRequest", + ] + ] = None, + ) -> Tuple[int, ErrorResponse]: + """Wraps error message in OpenAI style error response""" + print(f"Exception: {str(error)}", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + if body is not None and isinstance( + body, + ( + CreateCompletionRequest, + CreateChatCompletionRequest, + ), + ): + # When text completion or chat completion + for pattern, callback in self.pattern_and_formatters.items(): + match = pattern.search(str(error)) + if match is not None: + return callback(body, match) + + # Wrap other errors as internal server error + return 500, ErrorResponse( + message=str(error), + type="internal_server_error", + param=None, + code=None, + ) + + def get_route_handler( + self, + ) -> Callable[[Request], Coroutine[None, None, Response]]: + """Defines custom route handler that catches exceptions and formats + in OpenAI style error response""" + + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + try: + start_sec = time.perf_counter() + response = await original_route_handler(request) + elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000) + response.headers["openai-processing-ms"] = f"{elapsed_time_ms}" + return response + except HTTPException as unauthorized: + # api key check failed + raise unauthorized + except Exception as exc: + json_body = await request.json() + try: + if "messages" in json_body: + # Chat completion + body: Optional[ + Union[ + CreateChatCompletionRequest, + CreateCompletionRequest, + CreateEmbeddingRequest, + ] + ] = CreateChatCompletionRequest(**json_body) + elif "prompt" in json_body: + # Text completion + body = CreateCompletionRequest(**json_body) + else: + # Embedding + body = CreateEmbeddingRequest(**json_body) + except Exception: + # Invalid request body + body = None + + # Get proper error message from the exception + ( + status_code, + error_message, + ) = self.error_message_wrapper(error=exc, body=body) + return JSONResponse( + {"error": error_message}, + 
status_code=status_code, + ) + + return custom_route_handler + diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py new file mode 100644 index 0000000..b9373b7 --- /dev/null +++ b/llama_cpp/server/model.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import Optional, Union, List + +import llama_cpp + +from llama_cpp.server.settings import ModelSettings + + +class LlamaProxy: + def __init__(self, models: List[ModelSettings]) -> None: + assert len(models) > 0, "No models provided!" + + self._model_settings_dict: dict[str, ModelSettings] = {} + for model in models: + if not model.model_alias: + model.model_alias = model.model + self._model_settings_dict[model.model_alias] = model + + self._current_model: Optional[llama_cpp.Llama] = None + self._current_model_alias: Optional[str] = None + + self._default_model_settings: ModelSettings = models[0] + self._default_model_alias: str = self._default_model_settings.model_alias # type: ignore + + # Load default model + self._current_model = self.load_llama_from_model_settings( + self._default_model_settings + ) + self._current_model_alias = self._default_model_alias + + def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama: + if model is None: + model = self._default_model_alias + + if model not in self._model_settings_dict: + model = self._default_model_alias + + if model == self._current_model_alias: + if self._current_model is not None: + return self._current_model + + self._current_model = None + + settings = self._model_settings_dict[model] + self._current_model = self.load_llama_from_model_settings(settings) + self._current_model_alias = model + return self._current_model + + def __getitem__(self, model: str): + return self._model_settings_dict[model].model_dump() + + def __setitem__(self, model: str, settings: Union[ModelSettings, str, bytes]): + if isinstance(settings, (bytes, str)): + settings = ModelSettings.model_validate_json(settings) + self._model_settings_dict[model] = settings + + def __iter__(self): + for model in self._model_settings_dict: + yield model + + def free(self): + if self._current_model: + del self._current_model + + @staticmethod + def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: + chat_handler = None + if settings.chat_format == "llava-1-5": + assert settings.clip_model_path is not None, "clip model not found" + chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler( + clip_model_path=settings.clip_model_path, verbose=settings.verbose + ) + + _model = llama_cpp.Llama( + model_path=settings.model, + # Model Params + n_gpu_layers=settings.n_gpu_layers, + main_gpu=settings.main_gpu, + tensor_split=settings.tensor_split, + vocab_only=settings.vocab_only, + use_mmap=settings.use_mmap, + use_mlock=settings.use_mlock, + # Context Params + seed=settings.seed, + n_ctx=settings.n_ctx, + n_batch=settings.n_batch, + n_threads=settings.n_threads, + n_threads_batch=settings.n_threads_batch, + rope_scaling_type=settings.rope_scaling_type, + rope_freq_base=settings.rope_freq_base, + rope_freq_scale=settings.rope_freq_scale, + yarn_ext_factor=settings.yarn_ext_factor, + yarn_attn_factor=settings.yarn_attn_factor, + yarn_beta_fast=settings.yarn_beta_fast, + yarn_beta_slow=settings.yarn_beta_slow, + yarn_orig_ctx=settings.yarn_orig_ctx, + mul_mat_q=settings.mul_mat_q, + logits_all=settings.logits_all, + embedding=settings.embedding, + offload_kqv=settings.offload_kqv, + # Sampling Params + last_n_tokens_size=settings.last_n_tokens_size, + # LoRA Params 
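LlamaProxy above can also be exercised in-process, outside the FastAPI app; a minimal sketch under the assumption that the (placeholder) GGUF files exist locally:

from llama_cpp.server.model import LlamaProxy
from llama_cpp.server.settings import ModelSettings

proxy = LlamaProxy(
    models=[
        ModelSettings(model="models/llama-2-7b.Q4_K_M.gguf", model_alias="llama-2"),
        ModelSettings(model="models/mistral-7b-instruct.Q4_K_M.gguf", model_alias="mistral"),
    ]
)
llm = proxy("mistral")  # only one model stays resident; switching aliases reloads
out = llm.create_completion("Hello", max_tokens=16)
print(out["choices"][0]["text"])
print(list(proxy))  # -> ["llama-2", "mistral"]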
+ lora_base=settings.lora_base, + lora_path=settings.lora_path, + # Backend Params + numa=settings.numa, + # Chat Format Params + chat_format=settings.chat_format, + chat_handler=chat_handler, + # Misc + verbose=settings.verbose, + ) + if settings.cache: + if settings.cache_type == "disk": + if settings.verbose: + print(f"Using disk cache with size {settings.cache_size}") + cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size) + else: + if settings.verbose: + print(f"Using ram cache with size {settings.cache_size}") + cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size) + _model.set_cache(cache) + return _model + diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py new file mode 100644 index 0000000..53ead74 --- /dev/null +++ b/llama_cpp/server/settings.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import multiprocessing + +from typing import Optional, List, Literal +from pydantic import Field +from pydantic_settings import BaseSettings + +import llama_cpp + +# Disable warning for model and model_alias settings +BaseSettings.model_config["protected_namespaces"] = () + + +class ModelSettings(BaseSettings): + model: str = Field( + description="The path to the model to use for generating completions." + ) + model_alias: Optional[str] = Field( + default=None, + description="The alias of the model to use for generating completions.", + ) + # Model Params + n_gpu_layers: int = Field( + default=0, + ge=-1, + description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.", + ) + main_gpu: int = Field( + default=0, + ge=0, + description="Main GPU to use.", + ) + tensor_split: Optional[List[float]] = Field( + default=None, + description="Split layers across multiple GPUs in proportion.", + ) + vocab_only: bool = Field( + default=False, description="Whether to only return the vocabulary." + ) + use_mmap: bool = Field( + default=llama_cpp.llama_mmap_supported(), + description="Use mmap.", + ) + use_mlock: bool = Field( + default=llama_cpp.llama_mlock_supported(), + description="Use mlock.", + ) + # Context Params + seed: int = Field( + default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random." + ) + n_ctx: int = Field(default=2048, ge=1, description="The context size.") + n_batch: int = Field( + default=512, ge=1, description="The batch size to use per eval." 
+ ) + n_threads: int = Field( + default=max(multiprocessing.cpu_count() // 2, 1), + ge=1, + description="The number of threads to use.", + ) + n_threads_batch: int = Field( + default=max(multiprocessing.cpu_count() // 2, 1), + ge=0, + description="The number of threads to use when batch processing.", + ) + rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED) + rope_freq_base: float = Field(default=0.0, description="RoPE base frequency") + rope_freq_scale: float = Field( + default=0.0, description="RoPE frequency scaling factor" + ) + yarn_ext_factor: float = Field(default=-1.0) + yarn_attn_factor: float = Field(default=1.0) + yarn_beta_fast: float = Field(default=32.0) + yarn_beta_slow: float = Field(default=1.0) + yarn_orig_ctx: int = Field(default=0) + mul_mat_q: bool = Field( + default=True, description="if true, use experimental mul_mat_q kernels" + ) + logits_all: bool = Field(default=True, description="Whether to return logits.") + embedding: bool = Field(default=True, description="Whether to use embeddings.") + offload_kqv: bool = Field( + default=False, description="Whether to offload kqv to the GPU." + ) + # Sampling Params + last_n_tokens_size: int = Field( + default=64, + ge=0, + description="Last n tokens to keep for repeat penalty calculation.", + ) + # LoRA Params + lora_base: Optional[str] = Field( + default=None, + description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.", + ) + lora_path: Optional[str] = Field( + default=None, + description="Path to a LoRA file to apply to the model.", + ) + # Backend Params + numa: bool = Field( + default=False, + description="Enable NUMA support.", + ) + # Chat Format Params + chat_format: str = Field( + default="llama-2", + description="Chat format to use.", + ) + clip_model_path: Optional[str] = Field( + default=None, + description="Path to a CLIP model to use for multi-modal chat completion.", + ) + # Cache Params + cache: bool = Field( + default=False, + description="Use a cache to reduce processing times for evaluated prompts.", + ) + cache_type: Literal["ram", "disk"] = Field( + default="ram", + description="The type of cache to use. Only used if cache is True.", + ) + cache_size: int = Field( + default=2 << 30, + description="The size of the cache in bytes. Only used if cache is True.", + ) + # Misc + verbose: bool = Field( + default=True, description="Whether to print debug information." + ) + + +class ServerSettings(BaseSettings): + # Uvicorn Settings + host: str = Field(default="localhost", description="Listen address") + port: int = Field(default=8000, description="Listen port") + ssl_keyfile: Optional[str] = Field( + default=None, description="SSL key file for HTTPS" + ) + ssl_certfile: Optional[str] = Field( + default=None, description="SSL certificate file for HTTPS" + ) + # FastAPI Settings + api_key: Optional[str] = Field( + default=None, + description="API key for authentication. 
If set, all requests need to be authenticated.",
+    )
+    interrupt_requests: bool = Field(
+        default=True,
+        description="Whether to interrupt requests when a new request is received.",
+    )
+
+
+class Settings(ServerSettings, ModelSettings):
+    pass
+
+
+class ConfigFileSettings(ServerSettings):
+    models: List[ModelSettings] = Field(
+        default=[], description="Model configs; overrides the default model settings."
+    )
diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py
new file mode 100644
index 0000000..f0867bc
--- /dev/null
+++ b/llama_cpp/server/types.py
@@ -0,0 +1,264 @@
+from __future__ import annotations
+
+from typing import List, Optional, Union, Dict
+from typing_extensions import TypedDict, Literal
+
+from pydantic import BaseModel, Field
+
+import llama_cpp
+
+
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
+
+max_tokens_field = Field(
+    default=16, ge=1, description="The maximum number of tokens to generate."
+)
+
+temperature_field = Field(
+    default=0.8,
+    ge=0.0,
+    le=2.0,
+    description="Adjust the randomness of the generated text.\n\n"
+    + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
+)
+
+top_p_field = Field(
+    default=0.95,
+    ge=0.0,
+    le=1.0,
+    description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
+    + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
+)
+
+min_p_field = Field(
+    default=0.05,
+    ge=0.0,
+    le=1.0,
+    description="Sets a minimum base probability threshold for token selection.\n\n"
+    + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
+)
+
+stop_field = Field(
+    default=None,
+    description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
+)
+
+stream_field = Field(
+    default=False,
+    description="Whether to stream the results as they are generated. Useful for chatbots.",
+)
+
+top_k_field = Field(
+    default=40,
+    ge=0,
+    description="Limit the next token selection to the K most probable tokens.\n\n"
+    + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
+)
+
+repeat_penalty_field = Field(
+    default=1.1,
+    ge=0.0,
+    description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
+    + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
+)
+
+presence_penalty_field = Field(
+    default=0.0,
+    ge=-2.0,
+    le=2.0,
+    description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
+)
+
+frequency_penalty_field = Field(
+    default=0.0,
+    ge=-2.0,
+    le=2.0,
+    description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
+)
+
+mirostat_mode_field = Field(
+    default=0,
+    ge=0,
+    le=2,
+    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
+)
+
+mirostat_tau_field = Field(
+    default=5.0,
+    ge=0.0,
+    le=10.0,
+    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
+)
+
+mirostat_eta_field = Field(
+    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
+)
+
+grammar = Field(
+    default=None,
+    description="A GBNF grammar (as a string) used to constrain the model's output.",
+)
+
+
+class CreateCompletionRequest(BaseModel):
+    prompt: Union[str, List[str]] = Field(
+        default="", description="The prompt to generate completions for."
+    )
+    suffix: Optional[str] = Field(
+        default=None,
+        description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
+    )
+    max_tokens: int = max_tokens_field
+    temperature: float = temperature_field
+    top_p: float = top_p_field
+    min_p: float = min_p_field
+    echo: bool = Field(
+        default=False,
+        description="Whether to echo the prompt in the generated text. Useful for chatbots.",
+    )
+    stop: Optional[Union[str, List[str]]] = stop_field
+    stream: bool = stream_field
+    logprobs: Optional[int] = Field(
+        default=None,
+        ge=0,
+        description="The number of logprobs to generate. If None, no logprobs are generated.",
+    )
+    presence_penalty: Optional[float] = presence_penalty_field
+    frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    seed: Optional[int] = Field(None)
+
+    # ignored or currently unsupported
+    model: Optional[str] = model_field
+    n: Optional[int] = 1
+    best_of: Optional[int] = 1
+    user: Optional[str] = Field(default=None)
+
+    # llama.cpp specific parameters
+    top_k: int = top_k_field
+    repeat_penalty: float = repeat_penalty_field
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
+    mirostat_mode: int = mirostat_mode_field
+    mirostat_tau: float = mirostat_tau_field
+    mirostat_eta: float = mirostat_eta_field
+    grammar: Optional[str] = None
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
+                    "stop": ["\n", "###"],
+                }
+            ]
+        }
+    }
+
+
+class CreateEmbeddingRequest(BaseModel):
+    model: Optional[str] = model_field
+    input: Union[str, List[str]] = Field(description="The input to embed.")
+    user: Optional[str] = Field(default=None)
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "input": "The food was delicious and the waiter...",
+                }
+            ]
+        }
+    }
+
+
+class ChatCompletionRequestMessage(BaseModel):
+    role: Literal["system", "user", "assistant", "function"] = Field(
+        default="user", description="The role of the message."
+    )
+    content: Optional[str] = Field(
+        default="", description="The content of the message."
+    )
+
+
+class CreateChatCompletionRequest(BaseModel):
+    messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
+        default=[], description="A list of messages to generate completions for."
+    )
+    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
+        default=None,
+        description="A list of functions to apply to the generated completions.",
+    )
+    function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
+        default=None,
+        description="A function to apply to the generated completions.",
+    )
+    tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
+        default=None,
+        description="A list of tools to apply to the generated completions.",
+    )
+    tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
+        default=None,
+        description="A tool to apply to the generated completions.",
+    )  # TODO: verify
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate. Defaults to inf",
+    )
+    temperature: float = temperature_field
+    top_p: float = top_p_field
+    min_p: float = min_p_field
+    stop: Optional[Union[str, List[str]]] = stop_field
+    stream: bool = stream_field
+    presence_penalty: Optional[float] = presence_penalty_field
+    frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    seed: Optional[int] = Field(None)
+    response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
+        default=None,
+    )
+
+    # ignored or currently unsupported
+    model: Optional[str] = model_field
+    n: Optional[int] = 1
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
+    top_k: int = top_k_field
+    repeat_penalty: float = repeat_penalty_field
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
+    mirostat_mode: int = mirostat_mode_field
+    mirostat_tau: float = mirostat_tau_field
+    mirostat_eta: float = mirostat_eta_field
+    grammar: Optional[str] = None
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "messages": [
+                        ChatCompletionRequestMessage(
+                            role="system", content="You are a helpful assistant."
+                        ).model_dump(),
+                        ChatCompletionRequestMessage(
+                            role="user", content="What is the capital of France?"
+                        ).model_dump(),
+                    ]
+                }
+            ]
+        }
+    }
+
+
+class ModelData(TypedDict):
+    id: str
+    object: Literal["model"]
+    owned_by: str
+    permissions: List[str]
+
+
+class ModelList(TypedDict):
+    object: Literal["list"]
+    data: List[ModelData]
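
The JSON config file consumed by `--config_file` (or the `CONFIG_FILE` environment variable) is simply the serialized form of `ConfigFileSettings` above: server settings at the top level plus a `models` list of per-model settings. The sketch below assumes the module layout introduced in this patch; the model paths, aliases, and chat formats are placeholders, not files shipped with the project.

```
# Sketch: generate a multi-model config file for the new server.
# Paths, aliases, and chat formats are placeholder assumptions, not part of the patch.
from llama_cpp.server.settings import ConfigFileSettings, ModelSettings

config = ConfigFileSettings(
    host="0.0.0.0",
    port=8000,
    models=[
        ModelSettings(
            model="models/llama-2-7b-chat.Q4_K_M.gguf",  # placeholder path
            model_alias="llama-2-7b-chat",
            chat_format="llama-2",
        ),
        ModelSettings(
            model="models/mistral-7b-instruct.Q4_K_M.gguf",  # placeholder path
            model_alias="mistral-7b-instruct",
            chat_format="chatml",
        ),
    ],
)

# __main__.py reads this file back with ConfigFileSettings.model_validate_json,
# so a hand-written config.json must follow the same shape.
with open("config.json", "w") as f:
    f.write(config.model_dump_json(indent=2))
```

A request then selects one of the configured models through the `model` field defined in types.py, and `/v1/models` reports the available entries using the `ModelList`/`ModelData` types.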