[Feat] Multi model support (#931)
* Update Llama class to handle chat_format & caching * Add settings.py * Add util.py & update __main__.py * multimodel * update settings.py * cleanup * delete util.py * Fix /v1/models endpoint * MultiLlama now iterable, app check-alive on "/" * instant model init if file is given * backward compability * revert model param mandatory * fix error * handle individual model config json * refactor * revert chathandler/clip_model changes * handle chat_handler in MulitLlama() * split settings into server/llama * reduce global vars * Update LlamaProxy to handle config files * Add free method to LlamaProxy * update arg parsers & install server alias * refactor cache settings * change server executable name * better var name * whitespace * Revert "whitespace" This reverts commit bc5cf51c64a95bfc9926e1bc58166059711a1cd8. * remove exe_name * Fix merge bugs * Fix type annotations * Fix type annotations * Fix uvicorn app factory * Fix settings * Refactor server * Remove formatting fix * Format * Use default model if not found in model settings * Fix * Cleanup * Fix * Fix * Remove unnused CommandLineSettings * Cleanup * Support default name for copilot-codex models --------- Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
parent
4a85442c35
commit
12b7f2f4e9
7 changed files with 1042 additions and 793 deletions
|
@ -9,7 +9,7 @@ export MODEL=../models/7B/...
|
|||
|
||||
Then run:
|
||||
```
|
||||
uvicorn llama_cpp.server.app:app --reload
|
||||
uvicorn llama_cpp.server.app:create_app --reload
|
||||
```
|
||||
|
||||
or
|
||||
|
@ -21,81 +21,68 @@ python3 -m llama_cpp.server
|
|||
Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from typing import List, Literal, Union
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llama_cpp.server.app import create_app, Settings
|
||||
from llama_cpp.server.app import create_app
|
||||
from llama_cpp.server.settings import (
|
||||
Settings,
|
||||
ServerSettings,
|
||||
ModelSettings,
|
||||
ConfigFileSettings,
|
||||
)
|
||||
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
|
||||
|
||||
def get_base_type(annotation):
|
||||
if getattr(annotation, '__origin__', None) is Literal:
|
||||
return type(annotation.__args__[0])
|
||||
elif getattr(annotation, '__origin__', None) is Union:
|
||||
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
|
||||
if non_optional_args:
|
||||
return get_base_type(non_optional_args[0])
|
||||
elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List:
|
||||
return get_base_type(annotation.__args__[0])
|
||||
|
||||
def main():
|
||||
description = "🦙 Llama.cpp python server. Host your own LLMs!🚀"
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
add_args_from_model(parser, Settings)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
type=str,
|
||||
help="Path to a config file to load.",
|
||||
)
|
||||
server_settings: ServerSettings | None = None
|
||||
model_settings: list[ModelSettings] = []
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
# Load server settings from config_file if provided
|
||||
config_file = os.environ.get("CONFIG_FILE", args.config_file)
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Config file {config_file} not found!")
|
||||
with open(config_file, "rb") as f:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||
model_settings = config_file_settings.models
|
||||
else:
|
||||
return annotation
|
||||
server_settings = parse_model_from_args(ServerSettings, args)
|
||||
model_settings = [parse_model_from_args(ModelSettings, args)]
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
assert server_settings is not None
|
||||
assert model_settings is not None
|
||||
app = create_app(
|
||||
server_settings=server_settings,
|
||||
model_settings=model_settings,
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=os.getenv("HOST", server_settings.host),
|
||||
port=int(os.getenv("PORT", server_settings.port)),
|
||||
ssl_keyfile=server_settings.ssl_keyfile,
|
||||
ssl_certfile=server_settings.ssl_certfile,
|
||||
)
|
||||
|
||||
def contains_list_type(annotation) -> bool:
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
|
||||
if origin is list or origin is List:
|
||||
return True
|
||||
elif origin in (Literal, Union):
|
||||
return any(contains_list_type(arg) for arg in annotation.__args__)
|
||||
else:
|
||||
return False
|
||||
|
||||
def parse_bool_arg(arg):
|
||||
if isinstance(arg, bytes):
|
||||
arg = arg.decode('utf-8')
|
||||
|
||||
true_values = {'1', 'on', 't', 'true', 'y', 'yes'}
|
||||
false_values = {'0', 'off', 'f', 'false', 'n', 'no'}
|
||||
|
||||
arg_str = str(arg).lower().strip()
|
||||
|
||||
if arg_str in true_values:
|
||||
return True
|
||||
elif arg_str in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f'Invalid boolean argument: {arg}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
for name, field in Settings.model_fields.items():
|
||||
description = field.description
|
||||
if field.default is not None and description is not None:
|
||||
description += f" (default: {field.default})"
|
||||
base_type = get_base_type(field.annotation) if field.annotation is not None else str
|
||||
list_type = contains_list_type(field.annotation)
|
||||
if base_type is not bool:
|
||||
parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*" if list_type else None,
|
||||
type=base_type,
|
||||
help=description,
|
||||
)
|
||||
if base_type is bool:
|
||||
parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=parse_bool_arg,
|
||||
help=f"{description}",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
settings = Settings(**{k: v for k, v in vars(args).items() if v is not None})
|
||||
app = create_app(settings=settings)
|
||||
|
||||
uvicorn.run(
|
||||
app, host=os.getenv("HOST", settings.host), port=int(os.getenv("PORT", settings.port)),
|
||||
ssl_keyfile=settings.ssl_keyfile, ssl_certfile=settings.ssl_certfile
|
||||
)
|
||||
main()
|
||||
|
|
|
@ -1,375 +1,120 @@
|
|||
import sys
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
import traceback
|
||||
import multiprocessing
|
||||
import time
|
||||
from re import compile, Match, Pattern
|
||||
|
||||
from threading import Lock
|
||||
from functools import partial
|
||||
from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict
|
||||
from typing_extensions import TypedDict, Literal
|
||||
from typing import Iterator, List, Optional, Union, Dict
|
||||
|
||||
import llama_cpp
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
||||
from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
APIRouter,
|
||||
Request,
|
||||
HTTPException,
|
||||
status,
|
||||
)
|
||||
from fastapi.middleware import Middleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.security import HTTPBearer
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from starlette_context import plugins
|
||||
from starlette_context.plugins import RequestIdPlugin # type: ignore
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
# Disable warning for model and model_alias settings
|
||||
BaseSettings.model_config["protected_namespaces"] = ()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model: str = Field(
|
||||
description="The path to the model to use for generating completions."
|
||||
from llama_cpp.server.model import (
|
||||
LlamaProxy,
|
||||
)
|
||||
model_alias: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The alias of the model to use for generating completions.",
|
||||
from llama_cpp.server.settings import (
|
||||
ConfigFileSettings,
|
||||
Settings,
|
||||
ModelSettings,
|
||||
ServerSettings,
|
||||
)
|
||||
# Model Params
|
||||
n_gpu_layers: int = Field(
|
||||
default=0,
|
||||
ge=-1,
|
||||
description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
|
||||
)
|
||||
main_gpu: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Main GPU to use.",
|
||||
)
|
||||
tensor_split: Optional[List[float]] = Field(
|
||||
default=None,
|
||||
description="Split layers across multiple GPUs in proportion.",
|
||||
)
|
||||
vocab_only: bool = Field(
|
||||
default=False, description="Whether to only return the vocabulary."
|
||||
)
|
||||
use_mmap: bool = Field(
|
||||
default=llama_cpp.llama_mmap_supported(),
|
||||
description="Use mmap.",
|
||||
)
|
||||
use_mlock: bool = Field(
|
||||
default=llama_cpp.llama_mlock_supported(),
|
||||
description="Use mlock.",
|
||||
)
|
||||
# Context Params
|
||||
seed: int = Field(
|
||||
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
|
||||
)
|
||||
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
|
||||
n_batch: int = Field(
|
||||
default=512, ge=1, description="The batch size to use per eval."
|
||||
)
|
||||
n_threads: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=1,
|
||||
description="The number of threads to use.",
|
||||
)
|
||||
n_threads_batch: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=0,
|
||||
description="The number of threads to use when batch processing.",
|
||||
)
|
||||
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
|
||||
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
|
||||
rope_freq_scale: float = Field(
|
||||
default=0.0, description="RoPE frequency scaling factor"
|
||||
)
|
||||
yarn_ext_factor: float = Field(default=-1.0)
|
||||
yarn_attn_factor: float = Field(default=1.0)
|
||||
yarn_beta_fast: float = Field(default=32.0)
|
||||
yarn_beta_slow: float = Field(default=1.0)
|
||||
yarn_orig_ctx: int = Field(default=0)
|
||||
mul_mat_q: bool = Field(
|
||||
default=True, description="if true, use experimental mul_mat_q kernels"
|
||||
)
|
||||
logits_all: bool = Field(default=True, description="Whether to return logits.")
|
||||
embedding: bool = Field(default=True, description="Whether to use embeddings.")
|
||||
offload_kqv: bool = Field(
|
||||
default=False, description="Whether to offload kqv to the GPU."
|
||||
)
|
||||
# Sampling Params
|
||||
last_n_tokens_size: int = Field(
|
||||
default=64,
|
||||
ge=0,
|
||||
description="Last n tokens to keep for repeat penalty calculation.",
|
||||
)
|
||||
# LoRA Params
|
||||
lora_base: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
|
||||
)
|
||||
lora_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to a LoRA file to apply to the model.",
|
||||
)
|
||||
# Backend Params
|
||||
numa: bool = Field(
|
||||
default=False,
|
||||
description="Enable NUMA support.",
|
||||
)
|
||||
# Chat Format Params
|
||||
chat_format: str = Field(
|
||||
default="llama-2",
|
||||
description="Chat format to use.",
|
||||
)
|
||||
clip_model_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to a CLIP model to use for multi-modal chat completion.",
|
||||
)
|
||||
# Cache Params
|
||||
cache: bool = Field(
|
||||
default=False,
|
||||
description="Use a cache to reduce processing times for evaluated prompts.",
|
||||
)
|
||||
cache_type: Literal["ram", "disk"] = Field(
|
||||
default="ram",
|
||||
description="The type of cache to use. Only used if cache is True.",
|
||||
)
|
||||
cache_size: int = Field(
|
||||
default=2 << 30,
|
||||
description="The size of the cache in bytes. Only used if cache is True.",
|
||||
)
|
||||
# Misc
|
||||
verbose: bool = Field(
|
||||
default=True, description="Whether to print debug information."
|
||||
)
|
||||
# Server Params
|
||||
host: str = Field(default="localhost", description="Listen address")
|
||||
port: int = Field(default=8000, description="Listen port")
|
||||
# SSL Params
|
||||
ssl_keyfile: Optional[str] = Field(
|
||||
default=None, description="SSL key file for HTTPS"
|
||||
)
|
||||
ssl_certfile: Optional[str] = Field(
|
||||
default=None, description="SSL certificate file for HTTPS"
|
||||
)
|
||||
interrupt_requests: bool = Field(
|
||||
default=True,
|
||||
description="Whether to interrupt requests when a new request is received.",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for authentication. If set all requests need to be authenticated."
|
||||
)
|
||||
|
||||
|
||||
class ErrorResponse(TypedDict):
|
||||
"""OpenAI style error response"""
|
||||
|
||||
message: str
|
||||
type: str
|
||||
param: Optional[str]
|
||||
code: Optional[str]
|
||||
|
||||
|
||||
class ErrorResponseFormatters:
|
||||
"""Collection of formatters for error responses.
|
||||
|
||||
Args:
|
||||
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
|
||||
Request body
|
||||
match (Match[str]): Match object from regex pattern
|
||||
|
||||
Returns:
|
||||
Tuple[int, ErrorResponse]: Status code and error response
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def context_length_exceeded(
|
||||
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
match, # type: Match[str] # type: ignore
|
||||
) -> Tuple[int, ErrorResponse]:
|
||||
"""Formatter for context length exceeded error"""
|
||||
|
||||
context_window = int(match.group(2))
|
||||
prompt_tokens = int(match.group(1))
|
||||
completion_tokens = request.max_tokens
|
||||
if hasattr(request, "messages"):
|
||||
# Chat completion
|
||||
message = (
|
||||
"This model's maximum context length is {} tokens. "
|
||||
"However, you requested {} tokens "
|
||||
"({} in the messages, {} in the completion). "
|
||||
"Please reduce the length of the messages or completion."
|
||||
)
|
||||
else:
|
||||
# Text completion
|
||||
message = (
|
||||
"This model's maximum context length is {} tokens, "
|
||||
"however you requested {} tokens "
|
||||
"({} in your prompt; {} for the completion). "
|
||||
"Please reduce your prompt; or completion length."
|
||||
)
|
||||
return 400, ErrorResponse(
|
||||
message=message.format(
|
||||
context_window,
|
||||
completion_tokens + prompt_tokens,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
),
|
||||
type="invalid_request_error",
|
||||
param="messages",
|
||||
code="context_length_exceeded",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def model_not_found(
|
||||
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
match, # type: Match[str] # type: ignore
|
||||
) -> Tuple[int, ErrorResponse]:
|
||||
"""Formatter for model_not_found error"""
|
||||
|
||||
model_path = str(match.group(1))
|
||||
message = f"The model `{model_path}` does not exist"
|
||||
return 400, ErrorResponse(
|
||||
message=message,
|
||||
type="invalid_request_error",
|
||||
param=None,
|
||||
code="model_not_found",
|
||||
)
|
||||
|
||||
|
||||
class RouteErrorHandler(APIRoute):
|
||||
"""Custom APIRoute that handles application errors and exceptions"""
|
||||
|
||||
# key: regex pattern for original error message from llama_cpp
|
||||
# value: formatter function
|
||||
pattern_and_formatters: Dict[
|
||||
"Pattern",
|
||||
Callable[
|
||||
[
|
||||
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
"Match[str]",
|
||||
],
|
||||
Tuple[int, ErrorResponse],
|
||||
],
|
||||
] = {
|
||||
compile(
|
||||
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
|
||||
): ErrorResponseFormatters.context_length_exceeded,
|
||||
compile(
|
||||
r"Model path does not exist: (.+)"
|
||||
): ErrorResponseFormatters.model_not_found,
|
||||
}
|
||||
|
||||
def error_message_wrapper(
|
||||
self,
|
||||
error: Exception,
|
||||
body: Optional[
|
||||
Union[
|
||||
"CreateChatCompletionRequest",
|
||||
"CreateCompletionRequest",
|
||||
"CreateEmbeddingRequest",
|
||||
]
|
||||
] = None,
|
||||
) -> Tuple[int, ErrorResponse]:
|
||||
"""Wraps error message in OpenAI style error response"""
|
||||
print(f"Exception: {str(error)}", file=sys.stderr)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
if body is not None and isinstance(
|
||||
body,
|
||||
(
|
||||
CreateCompletionRequest,
|
||||
CreateChatCompletionRequest,
|
||||
),
|
||||
):
|
||||
# When text completion or chat completion
|
||||
for pattern, callback in self.pattern_and_formatters.items():
|
||||
match = pattern.search(str(error))
|
||||
if match is not None:
|
||||
return callback(body, match)
|
||||
|
||||
# Wrap other errors as internal server error
|
||||
return 500, ErrorResponse(
|
||||
message=str(error),
|
||||
type="internal_server_error",
|
||||
param=None,
|
||||
code=None,
|
||||
)
|
||||
|
||||
def get_route_handler(
|
||||
self,
|
||||
) -> Callable[[Request], Coroutine[None, None, Response]]:
|
||||
"""Defines custom route handler that catches exceptions and formats
|
||||
in OpenAI style error response"""
|
||||
|
||||
original_route_handler = super().get_route_handler()
|
||||
|
||||
async def custom_route_handler(request: Request) -> Response:
|
||||
try:
|
||||
start_sec = time.perf_counter()
|
||||
response = await original_route_handler(request)
|
||||
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
|
||||
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
|
||||
return response
|
||||
except HTTPException as unauthorized:
|
||||
# api key check failed
|
||||
raise unauthorized
|
||||
except Exception as exc:
|
||||
json_body = await request.json()
|
||||
try:
|
||||
if "messages" in json_body:
|
||||
# Chat completion
|
||||
body: Optional[
|
||||
Union[
|
||||
CreateChatCompletionRequest,
|
||||
from llama_cpp.server.types import (
|
||||
CreateCompletionRequest,
|
||||
CreateEmbeddingRequest,
|
||||
]
|
||||
] = CreateChatCompletionRequest(**json_body)
|
||||
elif "prompt" in json_body:
|
||||
# Text completion
|
||||
body = CreateCompletionRequest(**json_body)
|
||||
else:
|
||||
# Embedding
|
||||
body = CreateEmbeddingRequest(**json_body)
|
||||
except Exception:
|
||||
# Invalid request body
|
||||
body = None
|
||||
|
||||
# Get proper error message from the exception
|
||||
(
|
||||
status_code,
|
||||
error_message,
|
||||
) = self.error_message_wrapper(error=exc, body=body)
|
||||
return JSONResponse(
|
||||
{"error": error_message},
|
||||
status_code=status_code,
|
||||
CreateChatCompletionRequest,
|
||||
ModelList,
|
||||
)
|
||||
|
||||
return custom_route_handler
|
||||
from llama_cpp.server.errors import RouteErrorHandler
|
||||
|
||||
|
||||
router = APIRouter(route_class=RouteErrorHandler)
|
||||
|
||||
settings: Optional[Settings] = None
|
||||
llama: Optional[llama_cpp.Llama] = None
|
||||
_server_settings: Optional[ServerSettings] = None
|
||||
|
||||
|
||||
def create_app(settings: Optional[Settings] = None):
|
||||
def set_server_settings(server_settings: ServerSettings):
|
||||
global _server_settings
|
||||
_server_settings = server_settings
|
||||
|
||||
|
||||
def get_server_settings():
|
||||
yield _server_settings
|
||||
|
||||
|
||||
_llama_proxy: Optional[LlamaProxy] = None
|
||||
|
||||
llama_outer_lock = Lock()
|
||||
llama_inner_lock = Lock()
|
||||
|
||||
|
||||
def set_llama_proxy(model_settings: List[ModelSettings]):
|
||||
global _llama_proxy
|
||||
_llama_proxy = LlamaProxy(models=model_settings)
|
||||
|
||||
|
||||
def get_llama_proxy():
|
||||
# NOTE: This double lock allows the currently streaming llama model to
|
||||
# check if any other requests are pending in the same thread and cancel
|
||||
# the stream if so.
|
||||
llama_outer_lock.acquire()
|
||||
release_outer_lock = True
|
||||
try:
|
||||
llama_inner_lock.acquire()
|
||||
try:
|
||||
llama_outer_lock.release()
|
||||
release_outer_lock = False
|
||||
yield _llama_proxy
|
||||
finally:
|
||||
llama_inner_lock.release()
|
||||
finally:
|
||||
if release_outer_lock:
|
||||
llama_outer_lock.release()
|
||||
|
||||
|
||||
def create_app(
|
||||
settings: Settings | None = None,
|
||||
server_settings: ServerSettings | None = None,
|
||||
model_settings: List[ModelSettings] | None = None,
|
||||
):
|
||||
config_file = os.environ.get("CONFIG_FILE", None)
|
||||
if config_file is not None:
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Config file {config_file} not found!")
|
||||
with open(config_file, "rb") as f:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
|
||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||
model_settings = config_file_settings.models
|
||||
|
||||
if server_settings is None and model_settings is None:
|
||||
if settings is None:
|
||||
settings = Settings()
|
||||
server_settings = ServerSettings.model_validate(settings)
|
||||
model_settings = [ModelSettings.model_validate(settings)]
|
||||
|
||||
middleware = [
|
||||
Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
|
||||
]
|
||||
assert (
|
||||
server_settings is not None and model_settings is not None
|
||||
), "server_settings and model_settings must be provided together"
|
||||
|
||||
set_server_settings(server_settings)
|
||||
middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
|
||||
app = FastAPI(
|
||||
middleware=middleware,
|
||||
title="🦙 llama.cpp Python API",
|
||||
|
@ -383,105 +128,13 @@ def create_app(settings: Optional[Settings] = None):
|
|||
allow_headers=["*"],
|
||||
)
|
||||
app.include_router(router)
|
||||
global llama
|
||||
|
||||
##
|
||||
chat_handler = None
|
||||
if settings.chat_format == "llava-1-5":
|
||||
assert settings.clip_model_path is not None
|
||||
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
||||
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
||||
)
|
||||
##
|
||||
assert model_settings is not None
|
||||
set_llama_proxy(model_settings=model_settings)
|
||||
|
||||
llama = llama_cpp.Llama(
|
||||
model_path=settings.model,
|
||||
# Model Params
|
||||
n_gpu_layers=settings.n_gpu_layers,
|
||||
main_gpu=settings.main_gpu,
|
||||
tensor_split=settings.tensor_split,
|
||||
vocab_only=settings.vocab_only,
|
||||
use_mmap=settings.use_mmap,
|
||||
use_mlock=settings.use_mlock,
|
||||
# Context Params
|
||||
seed=settings.seed,
|
||||
n_ctx=settings.n_ctx,
|
||||
n_batch=settings.n_batch,
|
||||
n_threads=settings.n_threads,
|
||||
n_threads_batch=settings.n_threads_batch,
|
||||
rope_scaling_type=settings.rope_scaling_type,
|
||||
rope_freq_base=settings.rope_freq_base,
|
||||
rope_freq_scale=settings.rope_freq_scale,
|
||||
yarn_ext_factor=settings.yarn_ext_factor,
|
||||
yarn_attn_factor=settings.yarn_attn_factor,
|
||||
yarn_beta_fast=settings.yarn_beta_fast,
|
||||
yarn_beta_slow=settings.yarn_beta_slow,
|
||||
yarn_orig_ctx=settings.yarn_orig_ctx,
|
||||
mul_mat_q=settings.mul_mat_q,
|
||||
logits_all=settings.logits_all,
|
||||
embedding=settings.embedding,
|
||||
offload_kqv=settings.offload_kqv,
|
||||
# Sampling Params
|
||||
last_n_tokens_size=settings.last_n_tokens_size,
|
||||
# LoRA Params
|
||||
lora_base=settings.lora_base,
|
||||
lora_path=settings.lora_path,
|
||||
# Backend Params
|
||||
numa=settings.numa,
|
||||
# Chat Format Params
|
||||
chat_format=settings.chat_format,
|
||||
chat_handler=chat_handler,
|
||||
# Misc
|
||||
verbose=settings.verbose,
|
||||
)
|
||||
if settings.cache:
|
||||
if settings.cache_type == "disk":
|
||||
if settings.verbose:
|
||||
print(f"Using disk cache with size {settings.cache_size}")
|
||||
cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
|
||||
else:
|
||||
if settings.verbose:
|
||||
print(f"Using ram cache with size {settings.cache_size}")
|
||||
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
||||
|
||||
cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
|
||||
llama.set_cache(cache)
|
||||
|
||||
def set_settings(_settings: Settings):
|
||||
global settings
|
||||
settings = _settings
|
||||
|
||||
set_settings(settings)
|
||||
return app
|
||||
|
||||
|
||||
llama_outer_lock = Lock()
|
||||
llama_inner_lock = Lock()
|
||||
|
||||
|
||||
def get_llama():
|
||||
# NOTE: This double lock allows the currently streaming llama model to
|
||||
# check if any other requests are pending in the same thread and cancel
|
||||
# the stream if so.
|
||||
llama_outer_lock.acquire()
|
||||
release_outer_lock = True
|
||||
try:
|
||||
llama_inner_lock.acquire()
|
||||
try:
|
||||
llama_outer_lock.release()
|
||||
release_outer_lock = False
|
||||
yield llama
|
||||
finally:
|
||||
llama_inner_lock.release()
|
||||
finally:
|
||||
if release_outer_lock:
|
||||
llama_outer_lock.release()
|
||||
|
||||
|
||||
def get_settings():
|
||||
yield settings
|
||||
|
||||
|
||||
async def get_event_publisher(
|
||||
request: Request,
|
||||
inner_send_chan: MemoryObjectSendStream,
|
||||
|
@ -493,7 +146,10 @@ async def get_event_publisher(
|
|||
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
||||
if await request.is_disconnected():
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
if settings.interrupt_requests and llama_outer_lock.locked():
|
||||
if (
|
||||
next(get_server_settings()).interrupt_requests
|
||||
and llama_outer_lock.locked()
|
||||
):
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
|
@ -504,156 +160,6 @@ async def get_event_publisher(
|
|||
raise e
|
||||
|
||||
|
||||
model_field = Field(
|
||||
description="The model to use for generating completions.", default=None
|
||||
)
|
||||
|
||||
max_tokens_field = Field(
|
||||
default=16, ge=1, description="The maximum number of tokens to generate."
|
||||
)
|
||||
|
||||
temperature_field = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description="Adjust the randomness of the generated text.\n\n"
|
||||
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
|
||||
)
|
||||
|
||||
top_p_field = Field(
|
||||
default=0.95,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
|
||||
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
|
||||
)
|
||||
|
||||
min_p_field = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Sets a minimum base probability threshold for token selection.\n\n"
|
||||
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
|
||||
)
|
||||
|
||||
stop_field = Field(
|
||||
default=None,
|
||||
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
|
||||
)
|
||||
|
||||
stream_field = Field(
|
||||
default=False,
|
||||
description="Whether to stream the results as they are generated. Useful for chatbots.",
|
||||
)
|
||||
|
||||
top_k_field = Field(
|
||||
default=40,
|
||||
ge=0,
|
||||
description="Limit the next token selection to the K most probable tokens.\n\n"
|
||||
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
|
||||
)
|
||||
|
||||
repeat_penalty_field = Field(
|
||||
default=1.1,
|
||||
ge=0.0,
|
||||
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
|
||||
+ "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
|
||||
)
|
||||
|
||||
presence_penalty_field = Field(
|
||||
default=0.0,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
|
||||
)
|
||||
|
||||
frequency_penalty_field = Field(
|
||||
default=0.0,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
|
||||
)
|
||||
|
||||
mirostat_mode_field = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=2,
|
||||
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
|
||||
)
|
||||
|
||||
mirostat_tau_field = Field(
|
||||
default=5.0,
|
||||
ge=0.0,
|
||||
le=10.0,
|
||||
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
|
||||
)
|
||||
|
||||
mirostat_eta_field = Field(
|
||||
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
|
||||
)
|
||||
|
||||
grammar = Field(
|
||||
default=None,
|
||||
description="A CBNF grammar (as string) to be used for formatting the model's output.",
|
||||
)
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: Union[str, List[str]] = Field(
|
||||
default="", description="The prompt to generate completions for."
|
||||
)
|
||||
suffix: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
|
||||
)
|
||||
max_tokens: int = max_tokens_field
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
echo: bool = Field(
|
||||
default=False,
|
||||
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
|
||||
)
|
||||
stop: Optional[Union[str, List[str]]] = stop_field
|
||||
stream: bool = stream_field
|
||||
logprobs: Optional[int] = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
description="The number of logprobs to generate. If None, no logprobs are generated.",
|
||||
)
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
logprobs: Optional[int] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
n: Optional[int] = 1
|
||||
best_of: Optional[int] = 1
|
||||
user: Optional[str] = Field(default=None)
|
||||
|
||||
# llama.cpp specific parameters
|
||||
top_k: int = top_k_field
|
||||
repeat_penalty: float = repeat_penalty_field
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||
mirostat_mode: int = mirostat_mode_field
|
||||
mirostat_tau: float = mirostat_tau_field
|
||||
mirostat_eta: float = mirostat_eta_field
|
||||
grammar: Optional[str] = None
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
|
||||
"stop": ["\n", "###"],
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _logit_bias_tokens_to_input_ids(
|
||||
llama: llama_cpp.Llama,
|
||||
logit_bias: Dict[str, float],
|
||||
|
@ -670,7 +176,10 @@ def _logit_bias_tokens_to_input_ids(
|
|||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)):
|
||||
async def authenticate(
|
||||
settings: Settings = Depends(get_server_settings),
|
||||
authorization: Optional[str] = Depends(bearer_scheme),
|
||||
):
|
||||
# Skip API key check if it's not set in settings
|
||||
if settings.api_key is None:
|
||||
return True
|
||||
|
@ -688,20 +197,28 @@ async def authenticate(settings: Settings = Depends(get_settings), authorization
|
|||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
summary="Completion"
|
||||
"/v1/completions", summary="Completion", dependencies=[Depends(authenticate)]
|
||||
)
|
||||
@router.post(
|
||||
"/v1/engines/copilot-codex/completions",
|
||||
include_in_schema=False,
|
||||
dependencies=[Depends(authenticate)],
|
||||
)
|
||||
@router.post("/v1/engines/copilot-codex/completions", include_in_schema=False)
|
||||
async def create_completion(
|
||||
request: Request,
|
||||
body: CreateCompletionRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
authenticated: str = Depends(authenticate),
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> llama_cpp.Completion:
|
||||
if isinstance(body.prompt, list):
|
||||
assert len(body.prompt) <= 1
|
||||
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
|
||||
|
||||
llama = llama_proxy(
|
||||
body.model
|
||||
if request.url.path != "/v1/engines/copilot-codex/completions"
|
||||
else "copilot-codex"
|
||||
)
|
||||
|
||||
exclude = {
|
||||
"n",
|
||||
"best_of",
|
||||
|
@ -749,124 +266,26 @@ async def create_completion(
|
|||
return iterator_or_completion
|
||||
|
||||
|
||||
class CreateEmbeddingRequest(BaseModel):
|
||||
model: Optional[str] = model_field
|
||||
input: Union[str, List[str]] = Field(description="The input to embed.")
|
||||
user: Optional[str] = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"input": "The food was delicious and the waiter...",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/embeddings",
|
||||
summary="Embedding"
|
||||
"/v1/embeddings", summary="Embedding", dependencies=[Depends(authenticate)]
|
||||
)
|
||||
async def create_embedding(
|
||||
request: CreateEmbeddingRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
authenticated: str = Depends(authenticate),
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
):
|
||||
return await run_in_threadpool(
|
||||
llama.create_embedding, **request.model_dump(exclude={"user"})
|
||||
llama_proxy(request.model).create_embedding,
|
||||
**request.model_dump(exclude={"user"}),
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionRequestMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "function"] = Field(
|
||||
default="user", description="The role of the message."
|
||||
)
|
||||
content: Optional[str] = Field(
|
||||
default="", description="The content of the message."
|
||||
)
|
||||
|
||||
|
||||
class CreateChatCompletionRequest(BaseModel):
|
||||
messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
|
||||
default=[], description="A list of messages to generate completions for."
|
||||
)
|
||||
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
|
||||
default=None,
|
||||
description="A list of functions to apply to the generated completions.",
|
||||
)
|
||||
function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
|
||||
default=None,
|
||||
description="A function to apply to the generated completions.",
|
||||
)
|
||||
tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
|
||||
default=None,
|
||||
description="A list of tools to apply to the generated completions.",
|
||||
)
|
||||
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
|
||||
default=None,
|
||||
description="A tool to apply to the generated completions.",
|
||||
) # TODO: verify
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate. Defaults to inf",
|
||||
)
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
stop: Optional[Union[str, List[str]]] = stop_field
|
||||
stream: bool = stream_field
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
n: Optional[int] = 1
|
||||
user: Optional[str] = Field(None)
|
||||
|
||||
# llama.cpp specific parameters
|
||||
top_k: int = top_k_field
|
||||
repeat_penalty: float = repeat_penalty_field
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||
mirostat_mode: int = mirostat_mode_field
|
||||
mirostat_tau: float = mirostat_tau_field
|
||||
mirostat_eta: float = mirostat_eta_field
|
||||
grammar: Optional[str] = None
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"messages": [
|
||||
ChatCompletionRequestMessage(
|
||||
role="system", content="You are a helpful assistant."
|
||||
).model_dump(),
|
||||
ChatCompletionRequestMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
).model_dump(),
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/chat/completions",
|
||||
summary="Chat"
|
||||
"/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)]
|
||||
)
|
||||
async def create_chat_completion(
|
||||
request: Request,
|
||||
body: CreateChatCompletionRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
settings: Settings = Depends(get_settings),
|
||||
authenticated: str = Depends(authenticate),
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> llama_cpp.ChatCompletion:
|
||||
exclude = {
|
||||
"n",
|
||||
|
@ -874,7 +293,7 @@ async def create_chat_completion(
|
|||
"user",
|
||||
}
|
||||
kwargs = body.model_dump(exclude=exclude)
|
||||
|
||||
llama = llama_proxy(body.model)
|
||||
if body.logit_bias is not None:
|
||||
kwargs["logit_bias"] = (
|
||||
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
|
||||
|
@ -913,34 +332,19 @@ async def create_chat_completion(
|
|||
return iterator_or_completion
|
||||
|
||||
|
||||
class ModelData(TypedDict):
|
||||
id: str
|
||||
object: Literal["model"]
|
||||
owned_by: str
|
||||
permissions: List[str]
|
||||
|
||||
|
||||
class ModelList(TypedDict):
|
||||
object: Literal["list"]
|
||||
data: List[ModelData]
|
||||
|
||||
|
||||
@router.get("/v1/models", summary="Models")
|
||||
@router.get("/v1/models", summary="Models", dependencies=[Depends(authenticate)])
|
||||
async def get_models(
|
||||
settings: Settings = Depends(get_settings),
|
||||
authenticated: str = Depends(authenticate),
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> ModelList:
|
||||
assert llama is not None
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": settings.model_alias
|
||||
if settings.model_alias is not None
|
||||
else llama.model_path,
|
||||
"id": model_alias,
|
||||
"object": "model",
|
||||
"owned_by": "me",
|
||||
"permissions": [],
|
||||
}
|
||||
for model_alias in llama_proxy
|
||||
],
|
||||
}
|
||||
|
|
97
llama_cpp/server/cli.py
Normal file
97
llama_cpp/server/cli.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
from typing import List, Literal, Union, Any, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _get_base_type(annotation: Type[Any]) -> Type[Any]:
|
||||
if getattr(annotation, "__origin__", None) is Literal:
|
||||
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
|
||||
return type(annotation.__args__[0]) # type: ignore
|
||||
elif getattr(annotation, "__origin__", None) is Union:
|
||||
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
|
||||
non_optional_args: List[Type[Any]] = [
|
||||
arg for arg in annotation.__args__ if arg is not type(None) # type: ignore
|
||||
]
|
||||
if non_optional_args:
|
||||
return _get_base_type(non_optional_args[0])
|
||||
elif (
|
||||
getattr(annotation, "__origin__", None) is list
|
||||
or getattr(annotation, "__origin__", None) is List
|
||||
):
|
||||
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
|
||||
return _get_base_type(annotation.__args__[0]) # type: ignore
|
||||
return annotation
|
||||
|
||||
|
||||
def _contains_list_type(annotation: Type[Any] | None) -> bool:
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
|
||||
if origin is list or origin is List:
|
||||
return True
|
||||
elif origin in (Literal, Union):
|
||||
return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _parse_bool_arg(arg: str | bytes | bool) -> bool:
|
||||
if isinstance(arg, bytes):
|
||||
arg = arg.decode("utf-8")
|
||||
|
||||
true_values = {"1", "on", "t", "true", "y", "yes"}
|
||||
false_values = {"0", "off", "f", "false", "n", "no"}
|
||||
|
||||
arg_str = str(arg).lower().strip()
|
||||
|
||||
if arg_str in true_values:
|
||||
return True
|
||||
elif arg_str in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Invalid boolean argument: {arg}")
|
||||
|
||||
|
||||
def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]):
|
||||
"""Add arguments from a pydantic model to an argparse parser."""
|
||||
|
||||
for name, field in model.model_fields.items():
|
||||
description = field.description
|
||||
if field.default and description and not field.is_required():
|
||||
description += f" (default: {field.default})"
|
||||
base_type = (
|
||||
_get_base_type(field.annotation) if field.annotation is not None else str
|
||||
)
|
||||
list_type = _contains_list_type(field.annotation)
|
||||
if base_type is not bool:
|
||||
parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*" if list_type else None,
|
||||
type=base_type,
|
||||
help=description,
|
||||
)
|
||||
if base_type is bool:
|
||||
parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=_parse_bool_arg,
|
||||
help=f"{description}",
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=type[BaseModel])
|
||||
|
||||
|
||||
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
|
||||
"""Parse a pydantic model from an argparse namespace."""
|
||||
return model(
|
||||
**{
|
||||
k: v
|
||||
for k, v in vars(args).items()
|
||||
if v is not None and k in model.model_fields
|
||||
}
|
||||
)
|
210
llama_cpp/server/errors.py
Normal file
210
llama_cpp/server/errors.py
Normal file
|
@ -0,0 +1,210 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
import time
|
||||
from re import compile, Match, Pattern
|
||||
from typing import Callable, Coroutine, Optional, Tuple, Union, Dict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
from fastapi import (
|
||||
Request,
|
||||
Response,
|
||||
HTTPException,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from llama_cpp.server.types import (
|
||||
CreateCompletionRequest,
|
||||
CreateEmbeddingRequest,
|
||||
CreateChatCompletionRequest,
|
||||
)
|
||||
|
||||
class ErrorResponse(TypedDict):
|
||||
"""OpenAI style error response"""
|
||||
|
||||
message: str
|
||||
type: str
|
||||
param: Optional[str]
|
||||
code: Optional[str]
|
||||
|
||||
|
||||
class ErrorResponseFormatters:
|
||||
"""Collection of formatters for error responses.
|
||||
|
||||
Args:
|
||||
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
|
||||
Request body
|
||||
match (Match[str]): Match object from regex pattern
|
||||
|
||||
Returns:
|
||||
Tuple[int, ErrorResponse]: Status code and error response
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def context_length_exceeded(
|
||||
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
match, # type: Match[str] # type: ignore
|
||||
) -> Tuple[int, ErrorResponse]:
|
||||
"""Formatter for context length exceeded error"""
|
||||
|
||||
context_window = int(match.group(2))
|
||||
prompt_tokens = int(match.group(1))
|
||||
completion_tokens = request.max_tokens
|
||||
if hasattr(request, "messages"):
|
||||
# Chat completion
|
||||
message = (
|
||||
"This model's maximum context length is {} tokens. "
|
||||
"However, you requested {} tokens "
|
||||
"({} in the messages, {} in the completion). "
|
||||
"Please reduce the length of the messages or completion."
|
||||
)
|
||||
else:
|
||||
# Text completion
|
||||
message = (
|
||||
"This model's maximum context length is {} tokens, "
|
||||
"however you requested {} tokens "
|
||||
"({} in your prompt; {} for the completion). "
|
||||
"Please reduce your prompt; or completion length."
|
||||
)
|
||||
return 400, ErrorResponse(
|
||||
message=message.format(
|
||||
context_window,
|
||||
completion_tokens + prompt_tokens,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
), # type: ignore
|
||||
type="invalid_request_error",
|
||||
param="messages",
|
||||
code="context_length_exceeded",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def model_not_found(
|
||||
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
match, # type: Match[str] # type: ignore
|
||||
) -> Tuple[int, ErrorResponse]:
|
||||
"""Formatter for model_not_found error"""
|
||||
|
||||
model_path = str(match.group(1))
|
||||
message = f"The model `{model_path}` does not exist"
|
||||
return 400, ErrorResponse(
|
||||
message=message,
|
||||
type="invalid_request_error",
|
||||
param=None,
|
||||
code="model_not_found",
|
||||
)
|
||||
|
||||
|
||||
class RouteErrorHandler(APIRoute):
|
||||
"""Custom APIRoute that handles application errors and exceptions"""
|
||||
|
||||
# key: regex pattern for original error message from llama_cpp
|
||||
# value: formatter function
|
||||
pattern_and_formatters: Dict[
|
||||
"Pattern[str]",
|
||||
Callable[
|
||||
[
|
||||
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
"Match[str]",
|
||||
],
|
||||
Tuple[int, ErrorResponse],
|
||||
],
|
||||
] = {
|
||||
compile(
|
||||
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
|
||||
): ErrorResponseFormatters.context_length_exceeded,
|
||||
compile(
|
||||
r"Model path does not exist: (.+)"
|
||||
): ErrorResponseFormatters.model_not_found,
|
||||
}
|
||||
|
||||
def error_message_wrapper(
|
||||
self,
|
||||
error: Exception,
|
||||
body: Optional[
|
||||
Union[
|
||||
"CreateChatCompletionRequest",
|
||||
"CreateCompletionRequest",
|
||||
"CreateEmbeddingRequest",
|
||||
]
|
||||
] = None,
|
||||
) -> Tuple[int, ErrorResponse]:
|
||||
"""Wraps error message in OpenAI style error response"""
|
||||
print(f"Exception: {str(error)}", file=sys.stderr)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
if body is not None and isinstance(
|
||||
body,
|
||||
(
|
||||
CreateCompletionRequest,
|
||||
CreateChatCompletionRequest,
|
||||
),
|
||||
):
|
||||
# When text completion or chat completion
|
||||
for pattern, callback in self.pattern_and_formatters.items():
|
||||
match = pattern.search(str(error))
|
||||
if match is not None:
|
||||
return callback(body, match)
|
||||
|
||||
# Wrap other errors as internal server error
|
||||
return 500, ErrorResponse(
|
||||
message=str(error),
|
||||
type="internal_server_error",
|
||||
param=None,
|
||||
code=None,
|
||||
)
|
||||
|
||||
def get_route_handler(
|
||||
self,
|
||||
) -> Callable[[Request], Coroutine[None, None, Response]]:
|
||||
"""Defines custom route handler that catches exceptions and formats
|
||||
in OpenAI style error response"""
|
||||
|
||||
original_route_handler = super().get_route_handler()
|
||||
|
||||
async def custom_route_handler(request: Request) -> Response:
|
||||
try:
|
||||
start_sec = time.perf_counter()
|
||||
response = await original_route_handler(request)
|
||||
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
|
||||
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
|
||||
return response
|
||||
except HTTPException as unauthorized:
|
||||
# api key check failed
|
||||
raise unauthorized
|
||||
except Exception as exc:
|
||||
json_body = await request.json()
|
||||
try:
|
||||
if "messages" in json_body:
|
||||
# Chat completion
|
||||
body: Optional[
|
||||
Union[
|
||||
CreateChatCompletionRequest,
|
||||
CreateCompletionRequest,
|
||||
CreateEmbeddingRequest,
|
||||
]
|
||||
] = CreateChatCompletionRequest(**json_body)
|
||||
elif "prompt" in json_body:
|
||||
# Text completion
|
||||
body = CreateCompletionRequest(**json_body)
|
||||
else:
|
||||
# Embedding
|
||||
body = CreateEmbeddingRequest(**json_body)
|
||||
except Exception:
|
||||
# Invalid request body
|
||||
body = None
|
||||
|
||||
# Get proper error message from the exception
|
||||
(
|
||||
status_code,
|
||||
error_message,
|
||||
) = self.error_message_wrapper(error=exc, body=body)
|
||||
return JSONResponse(
|
||||
{"error": error_message},
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
return custom_route_handler
|
||||
|
126
llama_cpp/server/model.py
Normal file
126
llama_cpp/server/model.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import llama_cpp
|
||||
|
||||
from llama_cpp.server.settings import ModelSettings
|
||||
|
||||
|
||||
class LlamaProxy:
|
||||
def __init__(self, models: List[ModelSettings]) -> None:
|
||||
assert len(models) > 0, "No models provided!"
|
||||
|
||||
self._model_settings_dict: dict[str, ModelSettings] = {}
|
||||
for model in models:
|
||||
if not model.model_alias:
|
||||
model.model_alias = model.model
|
||||
self._model_settings_dict[model.model_alias] = model
|
||||
|
||||
self._current_model: Optional[llama_cpp.Llama] = None
|
||||
self._current_model_alias: Optional[str] = None
|
||||
|
||||
self._default_model_settings: ModelSettings = models[0]
|
||||
self._default_model_alias: str = self._default_model_settings.model_alias # type: ignore
|
||||
|
||||
# Load default model
|
||||
self._current_model = self.load_llama_from_model_settings(
|
||||
self._default_model_settings
|
||||
)
|
||||
self._current_model_alias = self._default_model_alias
|
||||
|
||||
def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
|
||||
if model is None:
|
||||
model = self._default_model_alias
|
||||
|
||||
if model not in self._model_settings_dict:
|
||||
model = self._default_model_alias
|
||||
|
||||
if model == self._current_model_alias:
|
||||
if self._current_model is not None:
|
||||
return self._current_model
|
||||
|
||||
self._current_model = None
|
||||
|
||||
settings = self._model_settings_dict[model]
|
||||
self._current_model = self.load_llama_from_model_settings(settings)
|
||||
self._current_model_alias = model
|
||||
return self._current_model
|
||||
|
||||
def __getitem__(self, model: str):
|
||||
return self._model_settings_dict[model].model_dump()
|
||||
|
||||
def __setitem__(self, model: str, settings: Union[ModelSettings, str, bytes]):
|
||||
if isinstance(settings, (bytes, str)):
|
||||
settings = ModelSettings.model_validate_json(settings)
|
||||
self._model_settings_dict[model] = settings
|
||||
|
||||
def __iter__(self):
|
||||
for model in self._model_settings_dict:
|
||||
yield model
|
||||
|
||||
def free(self):
|
||||
if self._current_model:
|
||||
del self._current_model
|
||||
|
||||
@staticmethod
|
||||
def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
|
||||
chat_handler = None
|
||||
if settings.chat_format == "llava-1-5":
|
||||
assert settings.clip_model_path is not None, "clip model not found"
|
||||
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
||||
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
||||
)
|
||||
|
||||
_model = llama_cpp.Llama(
|
||||
model_path=settings.model,
|
||||
# Model Params
|
||||
n_gpu_layers=settings.n_gpu_layers,
|
||||
main_gpu=settings.main_gpu,
|
||||
tensor_split=settings.tensor_split,
|
||||
vocab_only=settings.vocab_only,
|
||||
use_mmap=settings.use_mmap,
|
||||
use_mlock=settings.use_mlock,
|
||||
# Context Params
|
||||
seed=settings.seed,
|
||||
n_ctx=settings.n_ctx,
|
||||
n_batch=settings.n_batch,
|
||||
n_threads=settings.n_threads,
|
||||
n_threads_batch=settings.n_threads_batch,
|
||||
rope_scaling_type=settings.rope_scaling_type,
|
||||
rope_freq_base=settings.rope_freq_base,
|
||||
rope_freq_scale=settings.rope_freq_scale,
|
||||
yarn_ext_factor=settings.yarn_ext_factor,
|
||||
yarn_attn_factor=settings.yarn_attn_factor,
|
||||
yarn_beta_fast=settings.yarn_beta_fast,
|
||||
yarn_beta_slow=settings.yarn_beta_slow,
|
||||
yarn_orig_ctx=settings.yarn_orig_ctx,
|
||||
mul_mat_q=settings.mul_mat_q,
|
||||
logits_all=settings.logits_all,
|
||||
embedding=settings.embedding,
|
||||
offload_kqv=settings.offload_kqv,
|
||||
# Sampling Params
|
||||
last_n_tokens_size=settings.last_n_tokens_size,
|
||||
# LoRA Params
|
||||
lora_base=settings.lora_base,
|
||||
lora_path=settings.lora_path,
|
||||
# Backend Params
|
||||
numa=settings.numa,
|
||||
# Chat Format Params
|
||||
chat_format=settings.chat_format,
|
||||
chat_handler=chat_handler,
|
||||
# Misc
|
||||
verbose=settings.verbose,
|
||||
)
|
||||
if settings.cache:
|
||||
if settings.cache_type == "disk":
|
||||
if settings.verbose:
|
||||
print(f"Using disk cache with size {settings.cache_size}")
|
||||
cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
|
||||
else:
|
||||
if settings.verbose:
|
||||
print(f"Using ram cache with size {settings.cache_size}")
|
||||
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
||||
_model.set_cache(cache)
|
||||
return _model
|
||||
|
161
llama_cpp/server/settings.py
Normal file
161
llama_cpp/server/settings.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
|
||||
from typing import Optional, List, Literal
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
import llama_cpp
|
||||
|
||||
# Disable warning for model and model_alias settings
|
||||
BaseSettings.model_config["protected_namespaces"] = ()
|
||||
|
||||
|
||||
class ModelSettings(BaseSettings):
|
||||
model: str = Field(
|
||||
description="The path to the model to use for generating completions."
|
||||
)
|
||||
model_alias: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The alias of the model to use for generating completions.",
|
||||
)
|
||||
# Model Params
|
||||
n_gpu_layers: int = Field(
|
||||
default=0,
|
||||
ge=-1,
|
||||
description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
|
||||
)
|
||||
main_gpu: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Main GPU to use.",
|
||||
)
|
||||
tensor_split: Optional[List[float]] = Field(
|
||||
default=None,
|
||||
description="Split layers across multiple GPUs in proportion.",
|
||||
)
|
||||
vocab_only: bool = Field(
|
||||
default=False, description="Whether to only return the vocabulary."
|
||||
)
|
||||
use_mmap: bool = Field(
|
||||
default=llama_cpp.llama_mmap_supported(),
|
||||
description="Use mmap.",
|
||||
)
|
||||
use_mlock: bool = Field(
|
||||
default=llama_cpp.llama_mlock_supported(),
|
||||
description="Use mlock.",
|
||||
)
|
||||
# Context Params
|
||||
seed: int = Field(
|
||||
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
|
||||
)
|
||||
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
|
||||
n_batch: int = Field(
|
||||
default=512, ge=1, description="The batch size to use per eval."
|
||||
)
|
||||
n_threads: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=1,
|
||||
description="The number of threads to use.",
|
||||
)
|
||||
n_threads_batch: int = Field(
|
||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||
ge=0,
|
||||
description="The number of threads to use when batch processing.",
|
||||
)
|
||||
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
|
||||
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
|
||||
rope_freq_scale: float = Field(
|
||||
default=0.0, description="RoPE frequency scaling factor"
|
||||
)
|
||||
yarn_ext_factor: float = Field(default=-1.0)
|
||||
yarn_attn_factor: float = Field(default=1.0)
|
||||
yarn_beta_fast: float = Field(default=32.0)
|
||||
yarn_beta_slow: float = Field(default=1.0)
|
||||
yarn_orig_ctx: int = Field(default=0)
|
||||
mul_mat_q: bool = Field(
|
||||
default=True, description="if true, use experimental mul_mat_q kernels"
|
||||
)
|
||||
logits_all: bool = Field(default=True, description="Whether to return logits.")
|
||||
embedding: bool = Field(default=True, description="Whether to use embeddings.")
|
||||
offload_kqv: bool = Field(
|
||||
default=False, description="Whether to offload kqv to the GPU."
|
||||
)
|
||||
# Sampling Params
|
||||
last_n_tokens_size: int = Field(
|
||||
default=64,
|
||||
ge=0,
|
||||
description="Last n tokens to keep for repeat penalty calculation.",
|
||||
)
|
||||
# LoRA Params
|
||||
lora_base: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
|
||||
)
|
||||
lora_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to a LoRA file to apply to the model.",
|
||||
)
|
||||
# Backend Params
|
||||
numa: bool = Field(
|
||||
default=False,
|
||||
description="Enable NUMA support.",
|
||||
)
|
||||
# Chat Format Params
|
||||
chat_format: str = Field(
|
||||
default="llama-2",
|
||||
description="Chat format to use.",
|
||||
)
|
||||
clip_model_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to a CLIP model to use for multi-modal chat completion.",
|
||||
)
|
||||
# Cache Params
|
||||
cache: bool = Field(
|
||||
default=False,
|
||||
description="Use a cache to reduce processing times for evaluated prompts.",
|
||||
)
|
||||
cache_type: Literal["ram", "disk"] = Field(
|
||||
default="ram",
|
||||
description="The type of cache to use. Only used if cache is True.",
|
||||
)
|
||||
cache_size: int = Field(
|
||||
default=2 << 30,
|
||||
description="The size of the cache in bytes. Only used if cache is True.",
|
||||
)
|
||||
# Misc
|
||||
verbose: bool = Field(
|
||||
default=True, description="Whether to print debug information."
|
||||
)
|
||||
|
||||
|
||||
class ServerSettings(BaseSettings):
|
||||
# Uvicorn Settings
|
||||
host: str = Field(default="localhost", description="Listen address")
|
||||
port: int = Field(default=8000, description="Listen port")
|
||||
ssl_keyfile: Optional[str] = Field(
|
||||
default=None, description="SSL key file for HTTPS"
|
||||
)
|
||||
ssl_certfile: Optional[str] = Field(
|
||||
default=None, description="SSL certificate file for HTTPS"
|
||||
)
|
||||
# FastAPI Settings
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for authentication. If set all requests need to be authenticated.",
|
||||
)
|
||||
interrupt_requests: bool = Field(
|
||||
default=True,
|
||||
description="Whether to interrupt requests when a new request is received.",
|
||||
)
|
||||
|
||||
|
||||
class Settings(ServerSettings, ModelSettings):
|
||||
pass
|
||||
|
||||
|
||||
class ConfigFileSettings(ServerSettings):
|
||||
models: List[ModelSettings] = Field(
|
||||
default=[], description="Model configs, overwrites default config"
|
||||
)
|
264
llama_cpp/server/types.py
Normal file
264
llama_cpp/server/types.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union, Dict
|
||||
from typing_extensions import TypedDict, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import llama_cpp
|
||||
|
||||
|
||||
model_field = Field(
|
||||
description="The model to use for generating completions.", default=None
|
||||
)
|
||||
|
||||
max_tokens_field = Field(
|
||||
default=16, ge=1, description="The maximum number of tokens to generate."
|
||||
)
|
||||
|
||||
temperature_field = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description="Adjust the randomness of the generated text.\n\n"
|
||||
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
|
||||
)
|
||||
|
||||
top_p_field = Field(
|
||||
default=0.95,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
|
||||
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
|
||||
)
|
||||
|
||||
min_p_field = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Sets a minimum base probability threshold for token selection.\n\n"
|
||||
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
|
||||
)
|
||||
|
||||
stop_field = Field(
|
||||
default=None,
|
||||
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
|
||||
)
|
||||
|
||||
stream_field = Field(
|
||||
default=False,
|
||||
description="Whether to stream the results as they are generated. Useful for chatbots.",
|
||||
)
|
||||
|
||||
top_k_field = Field(
|
||||
default=40,
|
||||
ge=0,
|
||||
description="Limit the next token selection to the K most probable tokens.\n\n"
|
||||
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
|
||||
)
|
||||
|
||||
repeat_penalty_field = Field(
|
||||
default=1.1,
|
||||
ge=0.0,
|
||||
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
|
||||
+ "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
|
||||
)
|
||||
|
||||
presence_penalty_field = Field(
|
||||
default=0.0,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
|
||||
)
|
||||
|
||||
frequency_penalty_field = Field(
|
||||
default=0.0,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
|
||||
)
|
||||
|
||||
mirostat_mode_field = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=2,
|
||||
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
|
||||
)
|
||||
|
||||
mirostat_tau_field = Field(
|
||||
default=5.0,
|
||||
ge=0.0,
|
||||
le=10.0,
|
||||
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
|
||||
)
|
||||
|
||||
mirostat_eta_field = Field(
|
||||
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
|
||||
)
|
||||
|
||||
grammar = Field(
|
||||
default=None,
|
||||
description="A CBNF grammar (as string) to be used for formatting the model's output.",
|
||||
)
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: Union[str, List[str]] = Field(
|
||||
default="", description="The prompt to generate completions for."
|
||||
)
|
||||
suffix: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
|
||||
)
|
||||
max_tokens: int = max_tokens_field
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
echo: bool = Field(
|
||||
default=False,
|
||||
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
|
||||
)
|
||||
stop: Optional[Union[str, List[str]]] = stop_field
|
||||
stream: bool = stream_field
|
||||
logprobs: Optional[int] = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
description="The number of logprobs to generate. If None, no logprobs are generated.",
|
||||
)
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
logprobs: Optional[int] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
n: Optional[int] = 1
|
||||
best_of: Optional[int] = 1
|
||||
user: Optional[str] = Field(default=None)
|
||||
|
||||
# llama.cpp specific parameters
|
||||
top_k: int = top_k_field
|
||||
repeat_penalty: float = repeat_penalty_field
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||
mirostat_mode: int = mirostat_mode_field
|
||||
mirostat_tau: float = mirostat_tau_field
|
||||
mirostat_eta: float = mirostat_eta_field
|
||||
grammar: Optional[str] = None
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
|
||||
"stop": ["\n", "###"],
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CreateEmbeddingRequest(BaseModel):
|
||||
model: Optional[str] = model_field
|
||||
input: Union[str, List[str]] = Field(description="The input to embed.")
|
||||
user: Optional[str] = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"input": "The food was delicious and the waiter...",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ChatCompletionRequestMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "function"] = Field(
|
||||
default="user", description="The role of the message."
|
||||
)
|
||||
content: Optional[str] = Field(
|
||||
default="", description="The content of the message."
|
||||
)
|
||||
|
||||
|
||||
class CreateChatCompletionRequest(BaseModel):
|
||||
messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
|
||||
default=[], description="A list of messages to generate completions for."
|
||||
)
|
||||
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
|
||||
default=None,
|
||||
description="A list of functions to apply to the generated completions.",
|
||||
)
|
||||
function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
|
||||
default=None,
|
||||
description="A function to apply to the generated completions.",
|
||||
)
|
||||
tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
|
||||
default=None,
|
||||
description="A list of tools to apply to the generated completions.",
|
||||
)
|
||||
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
|
||||
default=None,
|
||||
description="A tool to apply to the generated completions.",
|
||||
) # TODO: verify
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate. Defaults to inf",
|
||||
)
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
stop: Optional[Union[str, List[str]]] = stop_field
|
||||
stream: bool = stream_field
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
n: Optional[int] = 1
|
||||
user: Optional[str] = Field(None)
|
||||
|
||||
# llama.cpp specific parameters
|
||||
top_k: int = top_k_field
|
||||
repeat_penalty: float = repeat_penalty_field
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||
mirostat_mode: int = mirostat_mode_field
|
||||
mirostat_tau: float = mirostat_tau_field
|
||||
mirostat_eta: float = mirostat_eta_field
|
||||
grammar: Optional[str] = None
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"messages": [
|
||||
ChatCompletionRequestMessage(
|
||||
role="system", content="You are a helpful assistant."
|
||||
).model_dump(),
|
||||
ChatCompletionRequestMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
).model_dump(),
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ModelData(TypedDict):
|
||||
id: str
|
||||
object: Literal["model"]
|
||||
owned_by: str
|
||||
permissions: List[str]
|
||||
|
||||
|
||||
class ModelList(TypedDict):
|
||||
object: Literal["list"]
|
||||
data: List[ModelData]
|
Loading…
Add table
Reference in a new issue