2023-09-30 19:13:36 -04:00
import sys
2023-04-28 22:43:37 -07:00
import json
2023-10-10 15:56:04 -04:00
import traceback
2023-05-07 03:03:57 -04:00
import multiprocessing
2023-09-25 13:55:58 -04:00
import time
2023-07-16 14:57:39 +09:00
from re import compile , Match , Pattern
2023-04-28 22:43:37 -07:00
from threading import Lock
2023-05-27 09:12:58 -04:00
from functools import partial
2023-07-20 18:52:10 -04:00
from typing import Callable , Coroutine , Iterator , List , Optional , Tuple , Union , Dict
2023-05-07 03:03:57 -04:00
from typing_extensions import TypedDict , Literal
2023-04-28 22:43:37 -07:00
import llama_cpp
2023-05-27 09:12:58 -04:00
import anyio
from anyio . streams . memory import MemoryObjectSendStream
from starlette . concurrency import run_in_threadpool , iterate_in_threadpool
2023-07-16 14:57:39 +09:00
from fastapi import Depends , FastAPI , APIRouter , Request , Response
2023-09-13 23:18:31 +03:00
from fastapi . middleware import Middleware
2023-04-28 22:43:37 -07:00
from fastapi . middleware . cors import CORSMiddleware
2023-07-16 14:57:39 +09:00
from fastapi . responses import JSONResponse
from fastapi . routing import APIRoute
2023-07-07 21:38:46 -04:00
from pydantic import BaseModel , Field
from pydantic_settings import BaseSettings
2023-04-28 22:43:37 -07:00
from sse_starlette . sse import EventSourceResponse
2023-09-13 23:18:31 +03:00
from starlette_context import plugins
from starlette_context . middleware import RawContextMiddleware
2023-04-28 22:43:37 -07:00
2023-07-18 19:27:41 -04:00
import numpy as np
import numpy . typing as npt
2023-04-28 22:43:37 -07:00
2023-09-29 19:59:12 -04:00
# Disable warning for model and model_alias settings: clearing the protected
# namespaces lets the Settings fields below start with the ``model`` prefix.
BaseSettings.model_config["protected_namespaces"] = ()
2023-04-28 22:43:37 -07:00
class Settings(BaseSettings):
    """Server configuration.

    The values are forwarded to ``llama_cpp.Llama`` by ``create_app`` (model,
    context, sampling, LoRA, backend and chat-format parameters) or consumed
    by the server itself (cache, host/port, request interruption).
    """

    model: str = Field(
        description="The path to the model to use for generating completions."
    )
    model_alias: Optional[str] = Field(
        default=None,
        description="The alias of the model to use for generating completions.",
    )
    # Model Params
    n_gpu_layers: int = Field(
        default=0,
        ge=-1,
        description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
    )
    main_gpu: int = Field(
        default=0,
        ge=0,
        description="Main GPU to use.",
    )
    tensor_split: Optional[List[float]] = Field(
        default=None,
        description="Split layers across multiple GPUs in proportion.",
    )
    vocab_only: bool = Field(
        default=False, description="Whether to only return the vocabulary."
    )
    use_mmap: bool = Field(
        default=llama_cpp.llama_mmap_supported(),
        description="Use mmap.",
    )
    use_mlock: bool = Field(
        default=llama_cpp.llama_mlock_supported(),
        description="Use mlock.",
    )
    # Context Params
    seed: int = Field(
        default=llama_cpp.LLAMA_DEFAULT_SEED,
        description="Random seed. -1 for random.",
    )
    n_ctx: int = Field(default=2048, ge=1, description="The context size.")
    n_batch: int = Field(
        default=512, ge=1, description="The batch size to use per eval."
    )
    n_threads: int = Field(
        # Half of the reported CPUs, but never fewer than one thread.
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=1,
        description="The number of threads to use.",
    )
    n_threads_batch: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=0,
        description="The number of threads to use when batch processing.",
    )
    rope_scaling_type: int = Field(
        default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
    )
    rope_freq_base: float = Field(
        default=0.0, description="RoPE base frequency"
    )
    rope_freq_scale: float = Field(
        default=0.0, description="RoPE frequency scaling factor"
    )
    yarn_ext_factor: float = Field(
        default=-1.0
    )
    yarn_attn_factor: float = Field(
        default=1.0
    )
    yarn_beta_fast: float = Field(
        default=32.0
    )
    yarn_beta_slow: float = Field(
        default=1.0
    )
    yarn_orig_ctx: int = Field(
        default=0
    )
    mul_mat_q: bool = Field(
        default=True, description="if true, use experimental mul_mat_q kernels"
    )
    f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.")
    logits_all: bool = Field(default=True, description="Whether to return logits.")
    embedding: bool = Field(default=True, description="Whether to use embeddings.")
    # Sampling Params
    last_n_tokens_size: int = Field(
        default=64,
        ge=0,
        description="Last n tokens to keep for repeat penalty calculation.",
    )
    # LoRA Params
    lora_base: Optional[str] = Field(
        default=None,
        description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
    )
    lora_path: Optional[str] = Field(
        default=None,
        description="Path to a LoRA file to apply to the model.",
    )
    # Backend Params
    numa: bool = Field(
        default=False,
        description="Enable NUMA support.",
    )
    # Chat Format Params
    chat_format: str = Field(
        default="llama-2",
        description="Chat format to use.",
    )
    clip_model_path: Optional[str] = Field(
        default=None,
        description="Path to a CLIP model to use for multi-modal chat completion.",
    )
    # Cache Params
    cache: bool = Field(
        default=False,
        description="Use a cache to reduce processing times for evaluated prompts.",
    )
    cache_type: Literal["ram", "disk"] = Field(
        default="ram",
        description="The type of cache to use. Only used if cache is True.",
    )
    cache_size: int = Field(
        default=2 << 30,  # 2 GiB
        description="The size of the cache in bytes. Only used if cache is True.",
    )
    # Misc
    verbose: bool = Field(
        default=True, description="Whether to print debug information."
    )
    # Server Params
    host: str = Field(default="localhost", description="Listen address")
    port: int = Field(default=8000, description="Listen port")
    interrupt_requests: bool = Field(
        default=True,
        description="Whether to interrupt requests when a new request is received.",
    )
2023-04-28 22:43:37 -07:00
2023-07-16 14:57:39 +09:00
class ErrorResponse(TypedDict):
    """OpenAI style error response"""

    # Human-readable description of the failure.
    message: str
    # Error category, e.g. "invalid_request_error" or "internal_server_error".
    type: str
    # Name of the offending request parameter, if any.
    param: Optional[str]
    # Machine-readable error code, if any.
    code: Optional[str]
class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """

    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str] # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error"""

        # Group 1 is the requested token count, group 2 the context window.
        context_window = int(match.group(2))
        prompt_tokens = int(match.group(1))
        # Chat requests may leave max_tokens unset (None); count that as zero
        # so building the message cannot fail with a TypeError.
        completion_tokens = request.max_tokens or 0
        if hasattr(request, "messages"):
            # Chat completion
            message = (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
        else:
            # Text completion
            message = (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        return 400, ErrorResponse(
            message=message.format(
                context_window,
                completion_tokens + prompt_tokens,
                prompt_tokens,
                completion_tokens,
            ),
            type="invalid_request_error",
            param="messages",
            code="context_length_exceeded",
        )

    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str] # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error"""

        model_path = str(match.group(1))
        message = f"The model `{model_path}` does not exist"
        return 400, ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )
class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions"""

    # key: regex pattern for original error message from llama_cpp
    # value: formatter function
    pattern_and_formatters: Dict[
        "Pattern",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response

        Args:
            error: The exception raised while handling the request.
            body: The parsed request body, or None when it could not be parsed.

        Returns:
            Status code and error payload. Known llama_cpp messages are mapped
            by ``pattern_and_formatters``; anything else becomes a 500.
        """
        print(f"Exception: {str(error)}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        if body is not None and isinstance(
            body,
            (
                CreateCompletionRequest,
                CreateChatCompletionRequest,
            ),
        ):
            # When text completion or chat completion
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(str(error))
                if match is not None:
                    return callback(body, match)

        # Wrap other errors as internal server error
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        in OpenAI style error response"""

        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                start_sec = time.perf_counter()
                response = await original_route_handler(request)
                elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                return response
            except Exception as exc:
                body: Optional[
                    Union[
                        CreateChatCompletionRequest,
                        CreateCompletionRequest,
                        CreateEmbeddingRequest,
                    ]
                ] = None
                try:
                    # Reading the JSON body is inside the try as well: a
                    # missing or malformed body must not raise a second
                    # exception while the first one is being reported.
                    json_body = await request.json()
                    if "messages" in json_body:
                        # Chat completion
                        body = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Invalid request body
                    body = None

                # Get proper error message from the exception
                (
                    status_code,
                    error_message,
                ) = self.error_message_wrapper(error=exc, body=body)
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler
# Router whose routes use RouteErrorHandler as their route class.
router = APIRouter(route_class=RouteErrorHandler)

# Module-level state; create_app() assigns both of these.
settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
def create_app(settings: Optional[Settings] = None):
    """Build the FastAPI application and load the llama model.

    Args:
        settings: Server configuration. When None, a default ``Settings()``
            instance is constructed.

    Returns:
        The configured ``FastAPI`` application.

    Side effects:
        Assigns the module-level ``llama`` and ``settings`` globals.
    """
    if settings is None:
        settings = Settings()

    middleware = [
        Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
    ]
    app = FastAPI(
        middleware=middleware,
        title="🦙 llama.cpp Python API",
        version="0.0.1",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)
    global llama

    # The llava-1-5 chat format needs a dedicated handler backed by a CLIP model.
    chat_handler = None
    if settings.chat_format == "llava-1-5":
        assert settings.clip_model_path is not None
        chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
            clip_model_path=settings.clip_model_path, verbose=settings.verbose
        )

    llama = llama_cpp.Llama(
        model_path=settings.model,
        # Model Params
        n_gpu_layers=settings.n_gpu_layers,
        main_gpu=settings.main_gpu,
        tensor_split=settings.tensor_split,
        vocab_only=settings.vocab_only,
        use_mmap=settings.use_mmap,
        use_mlock=settings.use_mlock,
        # Context Params
        seed=settings.seed,
        n_ctx=settings.n_ctx,
        n_batch=settings.n_batch,
        n_threads=settings.n_threads,
        n_threads_batch=settings.n_threads_batch,
        rope_scaling_type=settings.rope_scaling_type,
        rope_freq_base=settings.rope_freq_base,
        rope_freq_scale=settings.rope_freq_scale,
        yarn_ext_factor=settings.yarn_ext_factor,
        yarn_attn_factor=settings.yarn_attn_factor,
        yarn_beta_fast=settings.yarn_beta_fast,
        yarn_beta_slow=settings.yarn_beta_slow,
        yarn_orig_ctx=settings.yarn_orig_ctx,
        mul_mat_q=settings.mul_mat_q,
        f16_kv=settings.f16_kv,
        logits_all=settings.logits_all,
        embedding=settings.embedding,
        # Sampling Params
        last_n_tokens_size=settings.last_n_tokens_size,
        # LoRA Params
        lora_base=settings.lora_base,
        lora_path=settings.lora_path,
        # Backend Params
        numa=settings.numa,
        # Chat Format Params
        chat_format=settings.chat_format,
        chat_handler=chat_handler,
        # Misc
        verbose=settings.verbose,
    )
    if settings.cache:
        # Honour the configured cache type. The cache chosen here is the one
        # installed; it is no longer overwritten by a default LlamaCache.
        if settings.cache_type == "disk":
            if settings.verbose:
                print(f"Using disk cache with size {settings.cache_size}")
            cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
        else:
            if settings.verbose:
                print(f"Using ram cache with size {settings.cache_size}")
            cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
        llama.set_cache(cache)

    # ``settings`` is a parameter of this function, so it cannot also be
    # declared ``global`` here; a nested helper rebinds the module-level name.
    def set_settings(_settings: Settings):
        global settings
        settings = _settings

    set_settings(settings)
    return app
2023-04-28 22:43:37 -07:00
2023-07-07 03:04:17 -04:00
# Pair of locks used by get_llama() to hand the model to one request at a time.
llama_outer_lock = Lock()
llama_inner_lock = Lock()
2023-05-01 22:38:46 -04:00
2023-04-28 22:43:37 -07:00
def get_llama():
    """Yield the shared llama model while holding ``llama_inner_lock``.

    The outer lock is held only while waiting for the inner lock and is
    released as soon as the inner lock has been acquired (or acquiring it
    failed). The inner lock is released when the generator is finalized.
    """
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    with llama_outer_lock:
        llama_inner_lock.acquire()
    try:
        yield llama
    finally:
        llama_inner_lock.release()
2023-04-28 22:43:37 -07:00
2023-05-07 02:52:20 -04:00
2023-05-16 17:22:00 -04:00
def get_settings():
    """Yield the module-level ``settings`` object (generator-style provider)."""
    yield settings
2023-07-16 14:57:39 +09:00
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
):
    """Forward chunks from ``iterator`` to ``inner_send_chan`` as SSE events.

    Each chunk is JSON-encoded and sent as ``{"data": ...}``; a final
    ``[DONE]`` event marks the end of the stream. The stream is cancelled
    when the client disconnects, or when ``settings.interrupt_requests`` is
    set and ``llama_outer_lock`` is currently held.

    Args:
        request: Incoming request, polled for client disconnection.
        inner_send_chan: Channel the events are written to; closed on exit.
        iterator: Blocking iterator of chunks, consumed in a thread pool.

    Raises:
        The anyio cancellation exception, re-raised after logging.
    """
    async with inner_send_chan:
        try:
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send(dict(data=json.dumps(chunk)))
                if await request.is_disconnected():
                    raise anyio.get_cancelled_exc_class()()
                if settings.interrupt_requests and llama_outer_lock.locked():
                    # Terminate the stream cleanly before cancelling it.
                    await inner_send_chan.send(dict(data="[DONE]"))
                    raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class():
            print("disconnected")
            with anyio.move_on_after(1, shield=True):
                print(f"Disconnected from client (via refresh/close) {request.client}")
                # Bare raise keeps the original traceback intact.
                raise
2023-09-13 21:23:23 -04:00
# Shared pydantic Field definitions reused by the request models below.

model_field = Field(
    description="The model to use for generating completions.", default=None
)

max_tokens_field = Field(
    default=16, ge=1, description="The maximum number of tokens to generate."
)

temperature_field = Field(
    default=0.8,
    ge=0.0,
    le=2.0,
    description="Adjust the randomness of the generated text.\n\n"
    + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
)

top_p_field = Field(
    default=0.95,
    ge=0.0,
    le=1.0,
    description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
    + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
)

stop_field = Field(
    default=None,
    description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
)

stream_field = Field(
    default=False,
    description="Whether to stream the results as they are generated. Useful for chatbots.",
)

top_k_field = Field(
    default=40,
    ge=0,
    description="Limit the next token selection to the K most probable tokens.\n\n"
    + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
)

repeat_penalty_field = Field(
    default=1.1,
    ge=0.0,
    description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
    + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
)

presence_penalty_field = Field(
    default=0.0,
    ge=-2.0,
    le=2.0,
    description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
)

frequency_penalty_field = Field(
    default=0.0,
    ge=-2.0,
    le=2.0,
    description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
)

mirostat_mode_field = Field(
    default=0,
    ge=0,
    le=2,
    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
)

mirostat_tau_field = Field(
    default=5.0,
    ge=0.0,
    le=10.0,
    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
)

mirostat_eta_field = Field(
    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)

grammar = Field(
    default=None,
    description="A GBNF grammar (as string) to be used for formatting the model's output.",
)
2023-05-12 07:21:46 -04:00
2023-04-28 22:43:37 -07:00
class CreateCompletionRequest(BaseModel):
    """Request body for the text completion endpoints."""

    prompt: Union[str, List[str]] = Field(
        default="", description="The prompt to generate completions for."
    )
    suffix: Optional[str] = Field(
        default=None,
        description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
    )
    max_tokens: int = max_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    echo: bool = Field(
        default=False,
        description="Whether to echo the prompt in the generated text. Useful for chatbots.",
    )
    stop: Optional[Union[str, List[str]]] = stop_field
    stream: bool = stream_field
    # Declared exactly once: a second ``logprobs`` declaration further down
    # used to override this one and drop the ``ge=0`` constraint.
    logprobs: Optional[int] = Field(
        default=None,
        ge=0,
        description="The number of logprobs to generate. If None, no logprobs are generated.",
    )
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(None)
    seed: Optional[int] = Field(None)

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    best_of: Optional[int] = 1
    user: Optional[str] = Field(default=None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    grammar: Optional[str] = None

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                    "stop": ["\n", "###"],
                }
            ]
        }
    }
2023-04-28 22:43:37 -07:00
2023-06-09 13:13:08 -04:00
def make_logit_bias_processor(
    llama: "llama_cpp.Llama",
    logit_bias: Dict[str, float],
    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
    """Build a logits processor that adds fixed biases to selected token ids.

    Args:
        llama: Model whose ``tokenize`` method is used in ``"tokens"`` mode.
        logit_bias: Mapping of key to bias. Keys are token ids written as
            strings in ``"input_ids"`` mode, or token text in ``"tokens"``
            mode.
        logit_bias_type: How to interpret the keys; None means
            ``"input_ids"``.

    Returns:
        A callable ``(input_ids, scores) -> new_scores`` that returns a copy
        of ``scores`` with each bias added at its token id.
    """
    if logit_bias_type is None:
        logit_bias_type = "input_ids"

    to_bias: Dict[int, float] = {}
    if logit_bias_type == "input_ids":
        for input_id, score in logit_bias.items():
            input_id = int(input_id)
            to_bias[input_id] = score

    elif logit_bias_type == "tokens":
        for token, score in logit_bias.items():
            token = token.encode("utf-8")
            # Every id the text tokenizes to receives the same bias.
            for input_id in llama.tokenize(token, add_bos=False, special=True):
                to_bias[input_id] = score

    def logit_bias_processor(
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
    ) -> npt.NDArray[np.single]:
        # Work on a copy so the caller's score array is left untouched.
        new_scores = np.copy(scores)
        for input_id, score in to_bias.items():
            new_scores[input_id] = score + scores[input_id]
        return new_scores

    return logit_bias_processor
2023-05-01 22:38:46 -04:00
@router.post(
    "/v1/completions",
)
@router.post("/v1/engines/copilot-codex/completions")
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
    llama: llama_cpp.Llama = Depends(get_llama),
) -> llama_cpp.Completion:
    """Create a text completion, optionally streamed as server-sent events.

    Args:
        request: Incoming HTTP request; used to detect client disconnects
            while streaming.
        body: Validated completion parameters.
        llama: Model instance provided by ``get_llama``.

    Returns:
        The completion returned by the model, or an ``EventSourceResponse``
        when the model call returns an iterator.
    """
    if isinstance(body.prompt, list):
        # Only a single prompt per request is handled.
        assert len(body.prompt) <= 1
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    # Fields that are not passed through to the model call.
    exclude = {
        "n",
        "best_of",
        "logit_bias",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
            [
                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
            ]
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.CreateCompletionResponse,
        Iterator[llama_cpp.CreateCompletionStreamResponse],
    ] = await run_in_threadpool(llama, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
        )
    else:
        return iterator_or_completion
2023-04-28 22:43:37 -07:00
class CreateEmbeddingRequest(BaseModel):
    """Request body for the embeddings endpoint."""

    model: Optional[str] = model_field
    input: Union[str, List[str]] = Field(description="The input to embed.")
    user: Optional[str] = Field(default=None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "input": "The food was delicious and the waiter...",
                }
            ]
        }
    }
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
@router.post(
    "/v1/embeddings",
)
async def create_embedding(
    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
    """Compute embeddings for the request input.

    Args:
        request: Validated embedding parameters (this is the request body).
        llama: Model instance provided by ``get_llama``.

    Returns:
        Whatever ``llama.create_embedding`` returns; the call runs in a
        thread pool and receives every request field except ``user``.
    """
    return await run_in_threadpool(
        llama.create_embedding, **request.model_dump(exclude={"user"})
    )
2023-04-28 22:43:37 -07:00
class ChatCompletionRequestMessage(BaseModel):
    """A single chat message: who is speaking and what was said."""

    role: Literal["system", "user", "assistant", "function"] = Field(
        default="user", description="The role of the message."
    )
    content: Optional[str] = Field(default="", description="The content of the message.")
2023-04-28 22:43:37 -07:00
class CreateChatCompletionRequest(BaseModel):
    # Request body accepted by the chat completions route.
    # Names ending in `_field` are shared Field definitions declared elsewhere
    # in the file.
    # NOTE(review): string literals are reproduced exactly as they appear in
    # the source, including the padding spaces inside the quotes -- confirm
    # against the upstream file.

    # --- conversation, functions and tools ---
    messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
        description=" A list of messages to generate completions for. ",
        default=[],
    )
    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
        description=" A list of functions to apply to the generated completions. ",
        default=None,
    )
    function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
        description=" A function to apply to the generated completions. ",
        default=None,
    )
    tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
        description=" A list of tools to apply to the generated completions. ",
        default=None,
    )
    tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
        description=" A tool to apply to the generated completions. ",
        default=None,
    )  # TODO: verify

    # --- generation controls ---
    max_tokens: Optional[int] = Field(
        description=" The maximum number of tokens to generate. Defaults to inf ",
        default=None,
    )
    temperature: float = temperature_field
    top_p: float = top_p_field
    stop: Optional[List[str]] = stop_field
    stream: bool = stream_field
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(default=None)
    seed: Optional[int] = Field(default=None)
    response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
        default=None,
    )

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    user: Optional[str] = Field(default=None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    logit_bias_type: Optional[Literal[" input_ids ", " tokens "]] = Field(default=None)
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    grammar: Optional[str] = None

    # Example payload attached to the generated JSON schema: a system message
    # followed by a user message, each dumped to a plain dict.
    model_config = {
        " json_schema_extra ": {
            " examples ": [
                {
                    " messages ": [
                        ChatCompletionRequestMessage(
                            role=" system ",
                            content=" You are a helpful assistant. ",
                        ).model_dump(),
                        ChatCompletionRequestMessage(
                            role=" user ",
                            content=" What is the capital of France? ",
                        ).model_dump(),
                    ]
                }
            ]
        }
    }
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
@router.post(
    " /v1/chat/completions ",
)
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest,
    llama: llama_cpp.Llama = Depends(get_llama),
    settings: Settings = Depends(get_settings),
) -> llama_cpp.ChatCompletion:
    # Runs `llama.create_chat_completion` with the validated request body.
    #
    # When the call yields an iterator of chunks, the chunks are delivered
    # through an EventSourceResponse; otherwise the completion object is
    # returned as-is.
    #
    # NOTE(review): string literals (route path, excluded field names, kwargs
    # keys) are reproduced exactly as they appear in the source, padding spaces
    # included -- confirm against the upstream file.

    # Fields listed here are not forwarded as keyword arguments.
    completion_kwargs = body.model_dump(
        exclude={
            " n ",
            " logit_bias ",
            " logit_bias_type ",
            " user ",
        }
    )

    # A requested logit bias is forwarded as a logits processor instead of the
    # raw mapping.
    if body.logit_bias is not None:
        bias_processor = make_logit_bias_processor(
            llama, body.logit_bias, body.logit_bias_type
        )
        completion_kwargs[" logits_processor "] = llama_cpp.LogitsProcessorList(
            [bias_processor]
        )

    # The grammar arrives as text; replace it with a LlamaGrammar object.
    if body.grammar is not None:
        completion_kwargs[" grammar "] = llama_cpp.LlamaGrammar.from_string(
            body.grammar
        )

    result: Union[
        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
    ] = await run_in_threadpool(llama.create_chat_completion, **completion_kwargs)

    # Non-streaming case: hand the completion straight back.
    if not isinstance(result, Iterator):
        return result

    # EAFP: fetch the first chunk up front. If producing it raises, the
    # exception surfaces here, before any streaming response has been created.
    first_chunk = await run_in_threadpool(next, result)

    def replay_chunks() -> Iterator[llama_cpp.ChatCompletionChunk]:
        # Re-emit the chunk that was already consumed, then the remainder.
        yield first_chunk
        yield from result

    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    publisher = partial(  # type: ignore
        get_event_publisher,
        request=request,
        inner_send_chan=send_chan,
        iterator=replay_chunks(),
    )
    return EventSourceResponse(recv_chan, data_sender_callable=publisher)
2023-04-28 22:43:37 -07:00
class ModelData(TypedDict):
    # Typed shape of a single model entry (the element type of
    # `ModelList.data`).
    # NOTE(review): the Literal value is reproduced exactly as it appears in
    # the source, padding spaces included -- confirm against the upstream file.

    id: str
    object: Literal[" model "]
    owned_by: str
    permissions: List[str]
class ModelList(TypedDict):
    # Typed shape of a model listing: a fixed `object` tag plus the list of
    # model entries.
    # NOTE(review): the Literal value is reproduced exactly as it appears in
    # the source, padding spaces included -- confirm against the upstream file.

    object: Literal[" list "]
    data: List[ModelData]
2023-07-07 21:38:46 -04:00
@router.get(" /v1/models ")
async def get_models(
    settings: Settings = Depends(get_settings),
) -> ModelList:
    # Returns a listing that contains exactly one model entry.
    #
    # The entry's id is `settings.model_alias` when an alias is configured,
    # otherwise `llama.model_path`. `llama` is not a parameter of this
    # function; it is resolved from the enclosing module scope (defined
    # outside this section of the file).
    #
    # Raises RuntimeError when `llama` is None.
    #
    # NOTE(review): the route path and the dictionary keys/values are
    # reproduced exactly as they appear in the source, padding spaces
    # included -- confirm against the upstream file.

    # Explicit check instead of `assert`: assert statements are removed when
    # Python runs with -O, which would let a None `llama` through and fail
    # later with an opaque AttributeError.
    if llama is None:
        raise RuntimeError("llama model has not been initialized")

    if settings.model_alias is not None:
        model_id = settings.model_alias
    else:
        model_id = llama.model_path

    return {
        " object ": " list ",
        " data ": [
            {
                " id ": model_id,
                " object ": " model ",
                " owned_by ": " me ",
                " permissions ": [],
            }
        ],
    }