[Feat] Multi model support (#931)

* Update Llama class to handle chat_format & caching

* Add settings.py

* Add util.py & update __main__.py

* multimodel

* update settings.py

* cleanup

* delete util.py

* Fix /v1/models endpoint

* MultiLlama now iterable, app check-alive on "/"

* instant model init if file is given

* backward compability

* revert model param mandatory

* fix error

* handle individual model config json

* refactor

* revert chathandler/clip_model changes

* handle chat_handler in MulitLlama()

* split settings into server/llama

* reduce global vars

* Update LlamaProxy to handle config files

* Add free method to LlamaProxy

* update arg parsers & install server alias

* refactor cache settings

* change server executable name

* better var name

* whitespace

* Revert "whitespace"

This reverts commit bc5cf51c64a95bfc9926e1bc58166059711a1cd8.

* remove exe_name

* Fix merge bugs

* Fix type annotations

* Fix type annotations

* Fix uvicorn app factory

* Fix settings

* Refactor server

* Remove formatting fix

* Format

* Use default model if not found in model settings

* Fix

* Cleanup

* Fix

* Fix

* Remove unnused CommandLineSettings

* Cleanup

* Support default name for copilot-codex models

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
Dave 2023-12-22 11:51:25 +01:00 committed by GitHub
parent 4a85442c35
commit 12b7f2f4e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 1042 additions and 793 deletions

View file

@ -9,7 +9,7 @@ export MODEL=../models/7B/...
Then run:
```
uvicorn llama_cpp.server.app:app --reload
uvicorn llama_cpp.server.app:create_app --reload
```
or
@ -21,81 +21,68 @@ python3 -m llama_cpp.server
Then visit http://localhost:8000/docs to see the interactive API docs.
"""
from __future__ import annotations
import os
import sys
import argparse
from typing import List, Literal, Union
import uvicorn
from llama_cpp.server.app import create_app, Settings
from llama_cpp.server.app import create_app
from llama_cpp.server.settings import (
Settings,
ServerSettings,
ModelSettings,
ConfigFileSettings,
)
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
def get_base_type(annotation):
if getattr(annotation, '__origin__', None) is Literal:
return type(annotation.__args__[0])
elif getattr(annotation, '__origin__', None) is Union:
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
if non_optional_args:
return get_base_type(non_optional_args[0])
elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List:
return get_base_type(annotation.__args__[0])
def main():
description = "🦙 Llama.cpp python server. Host your own LLMs!🚀"
parser = argparse.ArgumentParser(description=description)
add_args_from_model(parser, Settings)
parser.add_argument(
"--config_file",
type=str,
help="Path to a config file to load.",
)
server_settings: ServerSettings | None = None
model_settings: list[ModelSettings] = []
args = parser.parse_args()
try:
# Load server settings from config_file if provided
config_file = os.environ.get("CONFIG_FILE", args.config_file)
if config_file:
if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models
else:
return annotation
server_settings = parse_model_from_args(ServerSettings, args)
model_settings = [parse_model_from_args(ModelSettings, args)]
except Exception as e:
print(e, file=sys.stderr)
parser.print_help()
sys.exit(1)
assert server_settings is not None
assert model_settings is not None
app = create_app(
server_settings=server_settings,
model_settings=model_settings,
)
uvicorn.run(
app,
host=os.getenv("HOST", server_settings.host),
port=int(os.getenv("PORT", server_settings.port)),
ssl_keyfile=server_settings.ssl_keyfile,
ssl_certfile=server_settings.ssl_certfile,
)
def contains_list_type(annotation) -> bool:
origin = getattr(annotation, '__origin__', None)
if origin is list or origin is List:
return True
elif origin in (Literal, Union):
return any(contains_list_type(arg) for arg in annotation.__args__)
else:
return False
def parse_bool_arg(arg):
if isinstance(arg, bytes):
arg = arg.decode('utf-8')
true_values = {'1', 'on', 't', 'true', 'y', 'yes'}
false_values = {'0', 'off', 'f', 'false', 'n', 'no'}
arg_str = str(arg).lower().strip()
if arg_str in true_values:
return True
elif arg_str in false_values:
return False
else:
raise ValueError(f'Invalid boolean argument: {arg}')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
for name, field in Settings.model_fields.items():
description = field.description
if field.default is not None and description is not None:
description += f" (default: {field.default})"
base_type = get_base_type(field.annotation) if field.annotation is not None else str
list_type = contains_list_type(field.annotation)
if base_type is not bool:
parser.add_argument(
f"--{name}",
dest=name,
nargs="*" if list_type else None,
type=base_type,
help=description,
)
if base_type is bool:
parser.add_argument(
f"--{name}",
dest=name,
type=parse_bool_arg,
help=f"{description}",
)
args = parser.parse_args()
settings = Settings(**{k: v for k, v in vars(args).items() if v is not None})
app = create_app(settings=settings)
uvicorn.run(
app, host=os.getenv("HOST", settings.host), port=int(os.getenv("PORT", settings.port)),
ssl_keyfile=settings.ssl_keyfile, ssl_certfile=settings.ssl_certfile
)
main()

View file

@ -1,375 +1,120 @@
import sys
from __future__ import annotations
import os
import json
import traceback
import multiprocessing
import time
from re import compile, Match, Pattern
from threading import Lock
from functools import partial
from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict
from typing_extensions import TypedDict, Literal
from typing import Iterator, List, Optional, Union, Dict
import llama_cpp
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
from fastapi import (
Depends,
FastAPI,
APIRouter,
Request,
HTTPException,
status,
)
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.security import HTTPBearer
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse
from starlette_context import plugins
from starlette_context.plugins import RequestIdPlugin # type: ignore
from starlette_context.middleware import RawContextMiddleware
import numpy as np
import numpy.typing as npt
# Disable warning for model and model_alias settings
BaseSettings.model_config["protected_namespaces"] = ()
class Settings(BaseSettings):
model: str = Field(
description="The path to the model to use for generating completions."
from llama_cpp.server.model import (
LlamaProxy,
)
model_alias: Optional[str] = Field(
default=None,
description="The alias of the model to use for generating completions.",
from llama_cpp.server.settings import (
ConfigFileSettings,
Settings,
ModelSettings,
ServerSettings,
)
# Model Params
n_gpu_layers: int = Field(
default=0,
ge=-1,
description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
)
main_gpu: int = Field(
default=0,
ge=0,
description="Main GPU to use.",
)
tensor_split: Optional[List[float]] = Field(
default=None,
description="Split layers across multiple GPUs in proportion.",
)
vocab_only: bool = Field(
default=False, description="Whether to only return the vocabulary."
)
use_mmap: bool = Field(
default=llama_cpp.llama_mmap_supported(),
description="Use mmap.",
)
use_mlock: bool = Field(
default=llama_cpp.llama_mlock_supported(),
description="Use mlock.",
)
# Context Params
seed: int = Field(
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
)
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
)
n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=1,
description="The number of threads to use.",
)
n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=0,
description="The number of threads to use when batch processing.",
)
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
)
yarn_ext_factor: float = Field(default=-1.0)
yarn_attn_factor: float = Field(default=1.0)
yarn_beta_fast: float = Field(default=32.0)
yarn_beta_slow: float = Field(default=1.0)
yarn_orig_ctx: int = Field(default=0)
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
logits_all: bool = Field(default=True, description="Whether to return logits.")
embedding: bool = Field(default=True, description="Whether to use embeddings.")
offload_kqv: bool = Field(
default=False, description="Whether to offload kqv to the GPU."
)
# Sampling Params
last_n_tokens_size: int = Field(
default=64,
ge=0,
description="Last n tokens to keep for repeat penalty calculation.",
)
# LoRA Params
lora_base: Optional[str] = Field(
default=None,
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
)
lora_path: Optional[str] = Field(
default=None,
description="Path to a LoRA file to apply to the model.",
)
# Backend Params
numa: bool = Field(
default=False,
description="Enable NUMA support.",
)
# Chat Format Params
chat_format: str = Field(
default="llama-2",
description="Chat format to use.",
)
clip_model_path: Optional[str] = Field(
default=None,
description="Path to a CLIP model to use for multi-modal chat completion.",
)
# Cache Params
cache: bool = Field(
default=False,
description="Use a cache to reduce processing times for evaluated prompts.",
)
cache_type: Literal["ram", "disk"] = Field(
default="ram",
description="The type of cache to use. Only used if cache is True.",
)
cache_size: int = Field(
default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
)
# Server Params
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
# SSL Params
ssl_keyfile: Optional[str] = Field(
default=None, description="SSL key file for HTTPS"
)
ssl_certfile: Optional[str] = Field(
default=None, description="SSL certificate file for HTTPS"
)
interrupt_requests: bool = Field(
default=True,
description="Whether to interrupt requests when a new request is received.",
)
api_key: Optional[str] = Field(
default=None,
description="API key for authentication. If set all requests need to be authenticated."
)
class ErrorResponse(TypedDict):
"""OpenAI style error response"""
message: str
type: str
param: Optional[str]
code: Optional[str]
class ErrorResponseFormatters:
"""Collection of formatters for error responses.
Args:
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
Request body
match (Match[str]): Match object from regex pattern
Returns:
Tuple[int, ErrorResponse]: Status code and error response
"""
@staticmethod
def context_length_exceeded(
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]:
"""Formatter for context length exceeded error"""
context_window = int(match.group(2))
prompt_tokens = int(match.group(1))
completion_tokens = request.max_tokens
if hasattr(request, "messages"):
# Chat completion
message = (
"This model's maximum context length is {} tokens. "
"However, you requested {} tokens "
"({} in the messages, {} in the completion). "
"Please reduce the length of the messages or completion."
)
else:
# Text completion
message = (
"This model's maximum context length is {} tokens, "
"however you requested {} tokens "
"({} in your prompt; {} for the completion). "
"Please reduce your prompt; or completion length."
)
return 400, ErrorResponse(
message=message.format(
context_window,
completion_tokens + prompt_tokens,
prompt_tokens,
completion_tokens,
),
type="invalid_request_error",
param="messages",
code="context_length_exceeded",
)
@staticmethod
def model_not_found(
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]:
"""Formatter for model_not_found error"""
model_path = str(match.group(1))
message = f"The model `{model_path}` does not exist"
return 400, ErrorResponse(
message=message,
type="invalid_request_error",
param=None,
code="model_not_found",
)
class RouteErrorHandler(APIRoute):
"""Custom APIRoute that handles application errors and exceptions"""
# key: regex pattern for original error message from llama_cpp
# value: formatter function
pattern_and_formatters: Dict[
"Pattern",
Callable[
[
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
"Match[str]",
],
Tuple[int, ErrorResponse],
],
] = {
compile(
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
): ErrorResponseFormatters.context_length_exceeded,
compile(
r"Model path does not exist: (.+)"
): ErrorResponseFormatters.model_not_found,
}
def error_message_wrapper(
self,
error: Exception,
body: Optional[
Union[
"CreateChatCompletionRequest",
"CreateCompletionRequest",
"CreateEmbeddingRequest",
]
] = None,
) -> Tuple[int, ErrorResponse]:
"""Wraps error message in OpenAI style error response"""
print(f"Exception: {str(error)}", file=sys.stderr)
traceback.print_exc(file=sys.stderr)
if body is not None and isinstance(
body,
(
CreateCompletionRequest,
CreateChatCompletionRequest,
),
):
# When text completion or chat completion
for pattern, callback in self.pattern_and_formatters.items():
match = pattern.search(str(error))
if match is not None:
return callback(body, match)
# Wrap other errors as internal server error
return 500, ErrorResponse(
message=str(error),
type="internal_server_error",
param=None,
code=None,
)
def get_route_handler(
self,
) -> Callable[[Request], Coroutine[None, None, Response]]:
"""Defines custom route handler that catches exceptions and formats
in OpenAI style error response"""
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
try:
start_sec = time.perf_counter()
response = await original_route_handler(request)
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
return response
except HTTPException as unauthorized:
# api key check failed
raise unauthorized
except Exception as exc:
json_body = await request.json()
try:
if "messages" in json_body:
# Chat completion
body: Optional[
Union[
CreateChatCompletionRequest,
from llama_cpp.server.types import (
CreateCompletionRequest,
CreateEmbeddingRequest,
]
] = CreateChatCompletionRequest(**json_body)
elif "prompt" in json_body:
# Text completion
body = CreateCompletionRequest(**json_body)
else:
# Embedding
body = CreateEmbeddingRequest(**json_body)
except Exception:
# Invalid request body
body = None
# Get proper error message from the exception
(
status_code,
error_message,
) = self.error_message_wrapper(error=exc, body=body)
return JSONResponse(
{"error": error_message},
status_code=status_code,
CreateChatCompletionRequest,
ModelList,
)
return custom_route_handler
from llama_cpp.server.errors import RouteErrorHandler
router = APIRouter(route_class=RouteErrorHandler)
settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
_server_settings: Optional[ServerSettings] = None
def create_app(settings: Optional[Settings] = None):
def set_server_settings(server_settings: ServerSettings):
global _server_settings
_server_settings = server_settings
def get_server_settings():
yield _server_settings
_llama_proxy: Optional[LlamaProxy] = None
llama_outer_lock = Lock()
llama_inner_lock = Lock()
def set_llama_proxy(model_settings: List[ModelSettings]):
global _llama_proxy
_llama_proxy = LlamaProxy(models=model_settings)
def get_llama_proxy():
# NOTE: This double lock allows the currently streaming llama model to
# check if any other requests are pending in the same thread and cancel
# the stream if so.
llama_outer_lock.acquire()
release_outer_lock = True
try:
llama_inner_lock.acquire()
try:
llama_outer_lock.release()
release_outer_lock = False
yield _llama_proxy
finally:
llama_inner_lock.release()
finally:
if release_outer_lock:
llama_outer_lock.release()
def create_app(
settings: Settings | None = None,
server_settings: ServerSettings | None = None,
model_settings: List[ModelSettings] | None = None,
):
config_file = os.environ.get("CONFIG_FILE", None)
if config_file is not None:
if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models
if server_settings is None and model_settings is None:
if settings is None:
settings = Settings()
server_settings = ServerSettings.model_validate(settings)
model_settings = [ModelSettings.model_validate(settings)]
middleware = [
Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
]
assert (
server_settings is not None and model_settings is not None
), "server_settings and model_settings must be provided together"
set_server_settings(server_settings)
middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
app = FastAPI(
middleware=middleware,
title="🦙 llama.cpp Python API",
@ -383,105 +128,13 @@ def create_app(settings: Optional[Settings] = None):
allow_headers=["*"],
)
app.include_router(router)
global llama
##
chat_handler = None
if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
##
assert model_settings is not None
set_llama_proxy(model_settings=model_settings)
llama = llama_cpp.Llama(
model_path=settings.model,
# Model Params
n_gpu_layers=settings.n_gpu_layers,
main_gpu=settings.main_gpu,
tensor_split=settings.tensor_split,
vocab_only=settings.vocab_only,
use_mmap=settings.use_mmap,
use_mlock=settings.use_mlock,
# Context Params
seed=settings.seed,
n_ctx=settings.n_ctx,
n_batch=settings.n_batch,
n_threads=settings.n_threads,
n_threads_batch=settings.n_threads_batch,
rope_scaling_type=settings.rope_scaling_type,
rope_freq_base=settings.rope_freq_base,
rope_freq_scale=settings.rope_freq_scale,
yarn_ext_factor=settings.yarn_ext_factor,
yarn_attn_factor=settings.yarn_attn_factor,
yarn_beta_fast=settings.yarn_beta_fast,
yarn_beta_slow=settings.yarn_beta_slow,
yarn_orig_ctx=settings.yarn_orig_ctx,
mul_mat_q=settings.mul_mat_q,
logits_all=settings.logits_all,
embedding=settings.embedding,
offload_kqv=settings.offload_kqv,
# Sampling Params
last_n_tokens_size=settings.last_n_tokens_size,
# LoRA Params
lora_base=settings.lora_base,
lora_path=settings.lora_path,
# Backend Params
numa=settings.numa,
# Chat Format Params
chat_format=settings.chat_format,
chat_handler=chat_handler,
# Misc
verbose=settings.verbose,
)
if settings.cache:
if settings.cache_type == "disk":
if settings.verbose:
print(f"Using disk cache with size {settings.cache_size}")
cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
else:
if settings.verbose:
print(f"Using ram cache with size {settings.cache_size}")
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
llama.set_cache(cache)
def set_settings(_settings: Settings):
global settings
settings = _settings
set_settings(settings)
return app
llama_outer_lock = Lock()
llama_inner_lock = Lock()
def get_llama():
# NOTE: This double lock allows the currently streaming llama model to
# check if any other requests are pending in the same thread and cancel
# the stream if so.
llama_outer_lock.acquire()
release_outer_lock = True
try:
llama_inner_lock.acquire()
try:
llama_outer_lock.release()
release_outer_lock = False
yield llama
finally:
llama_inner_lock.release()
finally:
if release_outer_lock:
llama_outer_lock.release()
def get_settings():
yield settings
async def get_event_publisher(
request: Request,
inner_send_chan: MemoryObjectSendStream,
@ -493,7 +146,10 @@ async def get_event_publisher(
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
if settings.interrupt_requests and llama_outer_lock.locked():
if (
next(get_server_settings()).interrupt_requests
and llama_outer_lock.locked()
):
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
@ -504,156 +160,6 @@ async def get_event_publisher(
raise e
model_field = Field(
description="The model to use for generating completions.", default=None
)
max_tokens_field = Field(
default=16, ge=1, description="The maximum number of tokens to generate."
)
temperature_field = Field(
default=0.8,
ge=0.0,
le=2.0,
description="Adjust the randomness of the generated text.\n\n"
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
)
top_p_field = Field(
default=0.95,
ge=0.0,
le=1.0,
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
)
min_p_field = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Sets a minimum base probability threshold for token selection.\n\n"
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
)
stop_field = Field(
default=None,
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
)
stream_field = Field(
default=False,
description="Whether to stream the results as they are generated. Useful for chatbots.",
)
top_k_field = Field(
default=40,
ge=0,
description="Limit the next token selection to the K most probable tokens.\n\n"
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
)
repeat_penalty_field = Field(
default=1.1,
ge=0.0,
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
+ "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
)
presence_penalty_field = Field(
default=0.0,
ge=-2.0,
le=2.0,
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
)
frequency_penalty_field = Field(
default=0.0,
ge=-2.0,
le=2.0,
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
)
mirostat_mode_field = Field(
default=0,
ge=0,
le=2,
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
)
mirostat_tau_field = Field(
default=5.0,
ge=0.0,
le=10.0,
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
)
mirostat_eta_field = Field(
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)
grammar = Field(
default=None,
description="A CBNF grammar (as string) to be used for formatting the model's output.",
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
default="", description="The prompt to generate completions for."
)
suffix: Optional[str] = Field(
default=None,
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
)
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
)
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
logprobs: Optional[int] = Field(
default=None,
ge=0,
description="The number of logprobs to generate. If None, no logprobs are generated.",
)
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
logprobs: Optional[int] = Field(None)
seed: Optional[int] = Field(None)
# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
best_of: Optional[int] = 1
user: Optional[str] = Field(default=None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
model_config = {
"json_schema_extra": {
"examples": [
{
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
]
}
}
def _logit_bias_tokens_to_input_ids(
llama: llama_cpp.Llama,
logit_bias: Dict[str, float],
@ -670,7 +176,10 @@ def _logit_bias_tokens_to_input_ids(
bearer_scheme = HTTPBearer(auto_error=False)
async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)):
async def authenticate(
settings: Settings = Depends(get_server_settings),
authorization: Optional[str] = Depends(bearer_scheme),
):
# Skip API key check if it's not set in settings
if settings.api_key is None:
return True
@ -688,20 +197,28 @@ async def authenticate(settings: Settings = Depends(get_settings), authorization
@router.post(
"/v1/completions",
summary="Completion"
"/v1/completions", summary="Completion", dependencies=[Depends(authenticate)]
)
@router.post(
"/v1/engines/copilot-codex/completions",
include_in_schema=False,
dependencies=[Depends(authenticate)],
)
@router.post("/v1/engines/copilot-codex/completions", include_in_schema=False)
async def create_completion(
request: Request,
body: CreateCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
authenticated: str = Depends(authenticate),
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.Completion:
if isinstance(body.prompt, list):
assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
llama = llama_proxy(
body.model
if request.url.path != "/v1/engines/copilot-codex/completions"
else "copilot-codex"
)
exclude = {
"n",
"best_of",
@ -749,124 +266,26 @@ async def create_completion(
return iterator_or_completion
class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field
input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str] = Field(default=None)
model_config = {
"json_schema_extra": {
"examples": [
{
"input": "The food was delicious and the waiter...",
}
]
}
}
@router.post(
"/v1/embeddings",
summary="Embedding"
"/v1/embeddings", summary="Embedding", dependencies=[Depends(authenticate)]
)
async def create_embedding(
request: CreateEmbeddingRequest,
llama: llama_cpp.Llama = Depends(get_llama),
authenticated: str = Depends(authenticate),
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
return await run_in_threadpool(
llama.create_embedding, **request.model_dump(exclude={"user"})
llama_proxy(request.model).create_embedding,
**request.model_dump(exclude={"user"}),
)
class ChatCompletionRequestMessage(BaseModel):
role: Literal["system", "user", "assistant", "function"] = Field(
default="user", description="The role of the message."
)
content: Optional[str] = Field(
default="", description="The content of the message."
)
class CreateChatCompletionRequest(BaseModel):
messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
default=[], description="A list of messages to generate completions for."
)
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
default=None,
description="A list of functions to apply to the generated completions.",
)
function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
default=None,
description="A function to apply to the generated completions.",
)
tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
default=None,
description="A list of tools to apply to the generated completions.",
)
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
default=None,
description="A tool to apply to the generated completions.",
) # TODO: verify
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate. Defaults to inf",
)
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None)
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
default=None,
)
# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
model_config = {
"json_schema_extra": {
"examples": [
{
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
).model_dump(),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
).model_dump(),
]
}
]
}
}
@router.post(
"/v1/chat/completions",
summary="Chat"
"/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)]
)
async def create_chat_completion(
request: Request,
body: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
settings: Settings = Depends(get_settings),
authenticated: str = Depends(authenticate),
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.ChatCompletion:
exclude = {
"n",
@ -874,7 +293,7 @@ async def create_chat_completion(
"user",
}
kwargs = body.model_dump(exclude=exclude)
llama = llama_proxy(body.model)
if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
@ -913,34 +332,19 @@ async def create_chat_completion(
return iterator_or_completion
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
@router.get("/v1/models", summary="Models")
@router.get("/v1/models", summary="Models", dependencies=[Depends(authenticate)])
async def get_models(
settings: Settings = Depends(get_settings),
authenticated: str = Depends(authenticate),
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
assert llama is not None
return {
"object": "list",
"data": [
{
"id": settings.model_alias
if settings.model_alias is not None
else llama.model_path,
"id": model_alias,
"object": "model",
"owned_by": "me",
"permissions": [],
}
for model_alias in llama_proxy
],
}

97
llama_cpp/server/cli.py Normal file
View file

@ -0,0 +1,97 @@
from __future__ import annotations
import argparse
from typing import List, Literal, Union, Any, Type, TypeVar
from pydantic import BaseModel
def _get_base_type(annotation: Type[Any]) -> Type[Any]:
if getattr(annotation, "__origin__", None) is Literal:
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
return type(annotation.__args__[0]) # type: ignore
elif getattr(annotation, "__origin__", None) is Union:
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
non_optional_args: List[Type[Any]] = [
arg for arg in annotation.__args__ if arg is not type(None) # type: ignore
]
if non_optional_args:
return _get_base_type(non_optional_args[0])
elif (
getattr(annotation, "__origin__", None) is list
or getattr(annotation, "__origin__", None) is List
):
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
return _get_base_type(annotation.__args__[0]) # type: ignore
return annotation
def _contains_list_type(annotation: Type[Any] | None) -> bool:
origin = getattr(annotation, "__origin__", None)
if origin is list or origin is List:
return True
elif origin in (Literal, Union):
return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore
else:
return False
def _parse_bool_arg(arg: str | bytes | bool) -> bool:
if isinstance(arg, bytes):
arg = arg.decode("utf-8")
true_values = {"1", "on", "t", "true", "y", "yes"}
false_values = {"0", "off", "f", "false", "n", "no"}
arg_str = str(arg).lower().strip()
if arg_str in true_values:
return True
elif arg_str in false_values:
return False
else:
raise ValueError(f"Invalid boolean argument: {arg}")
def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]):
"""Add arguments from a pydantic model to an argparse parser."""
for name, field in model.model_fields.items():
description = field.description
if field.default and description and not field.is_required():
description += f" (default: {field.default})"
base_type = (
_get_base_type(field.annotation) if field.annotation is not None else str
)
list_type = _contains_list_type(field.annotation)
if base_type is not bool:
parser.add_argument(
f"--{name}",
dest=name,
nargs="*" if list_type else None,
type=base_type,
help=description,
)
if base_type is bool:
parser.add_argument(
f"--{name}",
dest=name,
type=_parse_bool_arg,
help=f"{description}",
)
T = TypeVar("T", bound=type[BaseModel])
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
"""Parse a pydantic model from an argparse namespace."""
return model(
**{
k: v
for k, v in vars(args).items()
if v is not None and k in model.model_fields
}
)

210
llama_cpp/server/errors.py Normal file
View file

@ -0,0 +1,210 @@
from __future__ import annotations
import sys
import traceback
import time
from re import compile, Match, Pattern
from typing import Callable, Coroutine, Optional, Tuple, Union, Dict
from typing_extensions import TypedDict
from fastapi import (
Request,
Response,
HTTPException,
)
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from llama_cpp.server.types import (
CreateCompletionRequest,
CreateEmbeddingRequest,
CreateChatCompletionRequest,
)
class ErrorResponse(TypedDict):
"""OpenAI style error response"""
message: str
type: str
param: Optional[str]
code: Optional[str]
class ErrorResponseFormatters:
"""Collection of formatters for error responses.
Args:
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
Request body
match (Match[str]): Match object from regex pattern
Returns:
Tuple[int, ErrorResponse]: Status code and error response
"""
@staticmethod
def context_length_exceeded(
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]:
"""Formatter for context length exceeded error"""
context_window = int(match.group(2))
prompt_tokens = int(match.group(1))
completion_tokens = request.max_tokens
if hasattr(request, "messages"):
# Chat completion
message = (
"This model's maximum context length is {} tokens. "
"However, you requested {} tokens "
"({} in the messages, {} in the completion). "
"Please reduce the length of the messages or completion."
)
else:
# Text completion
message = (
"This model's maximum context length is {} tokens, "
"however you requested {} tokens "
"({} in your prompt; {} for the completion). "
"Please reduce your prompt; or completion length."
)
return 400, ErrorResponse(
message=message.format(
context_window,
completion_tokens + prompt_tokens,
prompt_tokens,
completion_tokens,
), # type: ignore
type="invalid_request_error",
param="messages",
code="context_length_exceeded",
)
@staticmethod
def model_not_found(
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]:
"""Formatter for model_not_found error"""
model_path = str(match.group(1))
message = f"The model `{model_path}` does not exist"
return 400, ErrorResponse(
message=message,
type="invalid_request_error",
param=None,
code="model_not_found",
)
class RouteErrorHandler(APIRoute):
"""Custom APIRoute that handles application errors and exceptions"""
# key: regex pattern for original error message from llama_cpp
# value: formatter function
pattern_and_formatters: Dict[
"Pattern[str]",
Callable[
[
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
"Match[str]",
],
Tuple[int, ErrorResponse],
],
] = {
compile(
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
): ErrorResponseFormatters.context_length_exceeded,
compile(
r"Model path does not exist: (.+)"
): ErrorResponseFormatters.model_not_found,
}
def error_message_wrapper(
self,
error: Exception,
body: Optional[
Union[
"CreateChatCompletionRequest",
"CreateCompletionRequest",
"CreateEmbeddingRequest",
]
] = None,
) -> Tuple[int, ErrorResponse]:
"""Wraps error message in OpenAI style error response"""
print(f"Exception: {str(error)}", file=sys.stderr)
traceback.print_exc(file=sys.stderr)
if body is not None and isinstance(
body,
(
CreateCompletionRequest,
CreateChatCompletionRequest,
),
):
# When text completion or chat completion
for pattern, callback in self.pattern_and_formatters.items():
match = pattern.search(str(error))
if match is not None:
return callback(body, match)
# Wrap other errors as internal server error
return 500, ErrorResponse(
message=str(error),
type="internal_server_error",
param=None,
code=None,
)
def get_route_handler(
self,
) -> Callable[[Request], Coroutine[None, None, Response]]:
"""Defines custom route handler that catches exceptions and formats
in OpenAI style error response"""
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
try:
start_sec = time.perf_counter()
response = await original_route_handler(request)
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
return response
except HTTPException as unauthorized:
# api key check failed
raise unauthorized
except Exception as exc:
json_body = await request.json()
try:
if "messages" in json_body:
# Chat completion
body: Optional[
Union[
CreateChatCompletionRequest,
CreateCompletionRequest,
CreateEmbeddingRequest,
]
] = CreateChatCompletionRequest(**json_body)
elif "prompt" in json_body:
# Text completion
body = CreateCompletionRequest(**json_body)
else:
# Embedding
body = CreateEmbeddingRequest(**json_body)
except Exception:
# Invalid request body
body = None
# Get proper error message from the exception
(
status_code,
error_message,
) = self.error_message_wrapper(error=exc, body=body)
return JSONResponse(
{"error": error_message},
status_code=status_code,
)
return custom_route_handler

126
llama_cpp/server/model.py Normal file
View file

@ -0,0 +1,126 @@
from __future__ import annotations
from typing import Optional, Union, List
import llama_cpp
from llama_cpp.server.settings import ModelSettings
class LlamaProxy:
def __init__(self, models: List[ModelSettings]) -> None:
assert len(models) > 0, "No models provided!"
self._model_settings_dict: dict[str, ModelSettings] = {}
for model in models:
if not model.model_alias:
model.model_alias = model.model
self._model_settings_dict[model.model_alias] = model
self._current_model: Optional[llama_cpp.Llama] = None
self._current_model_alias: Optional[str] = None
self._default_model_settings: ModelSettings = models[0]
self._default_model_alias: str = self._default_model_settings.model_alias # type: ignore
# Load default model
self._current_model = self.load_llama_from_model_settings(
self._default_model_settings
)
self._current_model_alias = self._default_model_alias
def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
if model is None:
model = self._default_model_alias
if model not in self._model_settings_dict:
model = self._default_model_alias
if model == self._current_model_alias:
if self._current_model is not None:
return self._current_model
self._current_model = None
settings = self._model_settings_dict[model]
self._current_model = self.load_llama_from_model_settings(settings)
self._current_model_alias = model
return self._current_model
def __getitem__(self, model: str):
return self._model_settings_dict[model].model_dump()
def __setitem__(self, model: str, settings: Union[ModelSettings, str, bytes]):
if isinstance(settings, (bytes, str)):
settings = ModelSettings.model_validate_json(settings)
self._model_settings_dict[model] = settings
def __iter__(self):
for model in self._model_settings_dict:
yield model
def free(self):
if self._current_model:
del self._current_model
@staticmethod
def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
chat_handler = None
if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None, "clip model not found"
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
_model = llama_cpp.Llama(
model_path=settings.model,
# Model Params
n_gpu_layers=settings.n_gpu_layers,
main_gpu=settings.main_gpu,
tensor_split=settings.tensor_split,
vocab_only=settings.vocab_only,
use_mmap=settings.use_mmap,
use_mlock=settings.use_mlock,
# Context Params
seed=settings.seed,
n_ctx=settings.n_ctx,
n_batch=settings.n_batch,
n_threads=settings.n_threads,
n_threads_batch=settings.n_threads_batch,
rope_scaling_type=settings.rope_scaling_type,
rope_freq_base=settings.rope_freq_base,
rope_freq_scale=settings.rope_freq_scale,
yarn_ext_factor=settings.yarn_ext_factor,
yarn_attn_factor=settings.yarn_attn_factor,
yarn_beta_fast=settings.yarn_beta_fast,
yarn_beta_slow=settings.yarn_beta_slow,
yarn_orig_ctx=settings.yarn_orig_ctx,
mul_mat_q=settings.mul_mat_q,
logits_all=settings.logits_all,
embedding=settings.embedding,
offload_kqv=settings.offload_kqv,
# Sampling Params
last_n_tokens_size=settings.last_n_tokens_size,
# LoRA Params
lora_base=settings.lora_base,
lora_path=settings.lora_path,
# Backend Params
numa=settings.numa,
# Chat Format Params
chat_format=settings.chat_format,
chat_handler=chat_handler,
# Misc
verbose=settings.verbose,
)
if settings.cache:
if settings.cache_type == "disk":
if settings.verbose:
print(f"Using disk cache with size {settings.cache_size}")
cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
else:
if settings.verbose:
print(f"Using ram cache with size {settings.cache_size}")
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
_model.set_cache(cache)
return _model

View file

@ -0,0 +1,161 @@
from __future__ import annotations
import multiprocessing
from typing import Optional, List, Literal
from pydantic import Field
from pydantic_settings import BaseSettings
import llama_cpp
# Disable warning for model and model_alias settings
BaseSettings.model_config["protected_namespaces"] = ()
class ModelSettings(BaseSettings):
model: str = Field(
description="The path to the model to use for generating completions."
)
model_alias: Optional[str] = Field(
default=None,
description="The alias of the model to use for generating completions.",
)
# Model Params
n_gpu_layers: int = Field(
default=0,
ge=-1,
description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
)
main_gpu: int = Field(
default=0,
ge=0,
description="Main GPU to use.",
)
tensor_split: Optional[List[float]] = Field(
default=None,
description="Split layers across multiple GPUs in proportion.",
)
vocab_only: bool = Field(
default=False, description="Whether to only return the vocabulary."
)
use_mmap: bool = Field(
default=llama_cpp.llama_mmap_supported(),
description="Use mmap.",
)
use_mlock: bool = Field(
default=llama_cpp.llama_mlock_supported(),
description="Use mlock.",
)
# Context Params
seed: int = Field(
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
)
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
)
n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=1,
description="The number of threads to use.",
)
n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=0,
description="The number of threads to use when batch processing.",
)
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
)
yarn_ext_factor: float = Field(default=-1.0)
yarn_attn_factor: float = Field(default=1.0)
yarn_beta_fast: float = Field(default=32.0)
yarn_beta_slow: float = Field(default=1.0)
yarn_orig_ctx: int = Field(default=0)
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
logits_all: bool = Field(default=True, description="Whether to return logits.")
embedding: bool = Field(default=True, description="Whether to use embeddings.")
offload_kqv: bool = Field(
default=False, description="Whether to offload kqv to the GPU."
)
# Sampling Params
last_n_tokens_size: int = Field(
default=64,
ge=0,
description="Last n tokens to keep for repeat penalty calculation.",
)
# LoRA Params
lora_base: Optional[str] = Field(
default=None,
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
)
lora_path: Optional[str] = Field(
default=None,
description="Path to a LoRA file to apply to the model.",
)
# Backend Params
numa: bool = Field(
default=False,
description="Enable NUMA support.",
)
# Chat Format Params
chat_format: str = Field(
default="llama-2",
description="Chat format to use.",
)
clip_model_path: Optional[str] = Field(
default=None,
description="Path to a CLIP model to use for multi-modal chat completion.",
)
# Cache Params
cache: bool = Field(
default=False,
description="Use a cache to reduce processing times for evaluated prompts.",
)
cache_type: Literal["ram", "disk"] = Field(
default="ram",
description="The type of cache to use. Only used if cache is True.",
)
cache_size: int = Field(
default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
)
class ServerSettings(BaseSettings):
# Uvicorn Settings
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
ssl_keyfile: Optional[str] = Field(
default=None, description="SSL key file for HTTPS"
)
ssl_certfile: Optional[str] = Field(
default=None, description="SSL certificate file for HTTPS"
)
# FastAPI Settings
api_key: Optional[str] = Field(
default=None,
description="API key for authentication. If set all requests need to be authenticated.",
)
interrupt_requests: bool = Field(
default=True,
description="Whether to interrupt requests when a new request is received.",
)
class Settings(ServerSettings, ModelSettings):
pass
class ConfigFileSettings(ServerSettings):
models: List[ModelSettings] = Field(
default=[], description="Model configs, overwrites default config"
)

264
llama_cpp/server/types.py Normal file
View file

@ -0,0 +1,264 @@
from __future__ import annotations
from typing import List, Optional, Union, Dict
from typing_extensions import TypedDict, Literal
from pydantic import BaseModel, Field
import llama_cpp
model_field = Field(
description="The model to use for generating completions.", default=None
)
max_tokens_field = Field(
default=16, ge=1, description="The maximum number of tokens to generate."
)
temperature_field = Field(
default=0.8,
ge=0.0,
le=2.0,
description="Adjust the randomness of the generated text.\n\n"
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
)
top_p_field = Field(
default=0.95,
ge=0.0,
le=1.0,
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
)
min_p_field = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Sets a minimum base probability threshold for token selection.\n\n"
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
)
stop_field = Field(
default=None,
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
)
stream_field = Field(
default=False,
description="Whether to stream the results as they are generated. Useful for chatbots.",
)
top_k_field = Field(
default=40,
ge=0,
description="Limit the next token selection to the K most probable tokens.\n\n"
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
)
repeat_penalty_field = Field(
default=1.1,
ge=0.0,
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
+ "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
)
presence_penalty_field = Field(
default=0.0,
ge=-2.0,
le=2.0,
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
)
frequency_penalty_field = Field(
default=0.0,
ge=-2.0,
le=2.0,
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
)
mirostat_mode_field = Field(
default=0,
ge=0,
le=2,
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
)
mirostat_tau_field = Field(
default=5.0,
ge=0.0,
le=10.0,
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
)
mirostat_eta_field = Field(
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)
grammar = Field(
default=None,
description="A CBNF grammar (as string) to be used for formatting the model's output.",
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
default="", description="The prompt to generate completions for."
)
suffix: Optional[str] = Field(
default=None,
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
)
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
)
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
logprobs: Optional[int] = Field(
default=None,
ge=0,
description="The number of logprobs to generate. If None, no logprobs are generated.",
)
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
logprobs: Optional[int] = Field(None)
seed: Optional[int] = Field(None)
# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
best_of: Optional[int] = 1
user: Optional[str] = Field(default=None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
model_config = {
"json_schema_extra": {
"examples": [
{
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
]
}
}
class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field
input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str] = Field(default=None)
model_config = {
"json_schema_extra": {
"examples": [
{
"input": "The food was delicious and the waiter...",
}
]
}
}
class ChatCompletionRequestMessage(BaseModel):
role: Literal["system", "user", "assistant", "function"] = Field(
default="user", description="The role of the message."
)
content: Optional[str] = Field(
default="", description="The content of the message."
)
class CreateChatCompletionRequest(BaseModel):
messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
default=[], description="A list of messages to generate completions for."
)
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
default=None,
description="A list of functions to apply to the generated completions.",
)
function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
default=None,
description="A function to apply to the generated completions.",
)
tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
default=None,
description="A list of tools to apply to the generated completions.",
)
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
default=None,
description="A tool to apply to the generated completions.",
) # TODO: verify
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate. Defaults to inf",
)
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None)
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
default=None,
)
# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
model_config = {
"json_schema_extra": {
"examples": [
{
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
).model_dump(),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
).model_dump(),
]
}
]
}
}
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]