2023-04-29 05:43:37 +00:00
import json
2023-05-07 07:03:57 +00:00
import multiprocessing
2023-07-16 05:57:39 +00:00
from re import compile , Match , Pattern
2023-04-29 05:43:37 +00:00
from threading import Lock
2023-05-27 13:12:58 +00:00
from functools import partial
2023-07-20 22:52:10 +00:00
from typing import Callable , Coroutine , Iterator , List , Optional , Tuple , Union , Dict
2023-05-07 07:03:57 +00:00
from typing_extensions import TypedDict , Literal
2023-04-29 05:43:37 +00:00
import llama_cpp
2023-05-27 13:12:58 +00:00
import anyio
from anyio . streams . memory import MemoryObjectSendStream
from starlette . concurrency import run_in_threadpool , iterate_in_threadpool
2023-07-16 05:57:39 +00:00
from fastapi import Depends , FastAPI , APIRouter , Request , Response
2023-09-13 20:18:31 +00:00
from fastapi . middleware import Middleware
2023-04-29 05:43:37 +00:00
from fastapi . middleware . cors import CORSMiddleware
2023-07-16 05:57:39 +00:00
from fastapi . responses import JSONResponse
from fastapi . routing import APIRoute
2023-07-08 01:38:46 +00:00
from pydantic import BaseModel , Field
from pydantic_settings import BaseSettings
2023-04-29 05:43:37 +00:00
from sse_starlette . sse import EventSourceResponse
2023-09-13 20:18:31 +00:00
from starlette_context import plugins
from starlette_context . middleware import RawContextMiddleware
2023-04-29 05:43:37 +00:00
2023-07-18 23:27:41 +00:00
import numpy as np
import numpy . typing as npt
2023-04-29 05:43:37 +00:00
2023-09-14 03:01:34 +00:00
# Empties the 'protected_namespaces' entry of BaseSettings' model_config at import time.
# NOTE(review): presumably done so the `model` / `model_alias` settings fields do not
# clash with pydantic's protected "model_" prefix — confirm against pydantic docs.
BaseSettings . model_config [ ' protected_namespaces ' ] = ( )
2023-04-29 05:43:37 +00:00
# Server configuration model. Every option is a pydantic Field carrying a default
# (except `model`, which has none) and a human-readable description.
# Most of these values are forwarded verbatim to llama_cpp.Llama(...) by create_app;
# `cache`, `cache_type` and `cache_size` select the prompt cache there, and
# `interrupt_requests` is read by get_event_publisher.
# NOTE(review): `numa`, `host` and `port` are not read anywhere in the visible part
# of this file — presumably consumed by the launcher; confirm.
class Settings ( BaseSettings ) :
2023-05-07 06:52:20 +00:00
# Required: no default is supplied for the model path.
model : str = Field (
description = " The path to the model to use for generating completions. "
)
2023-05-16 21:22:00 +00:00
# When set, get_models reports this alias instead of the model path.
model_alias : Optional [ str ] = Field (
default = None ,
description = " The alias of the model to use for generating completions. " ,
)
2023-09-14 01:23:13 +00:00
seed : int = Field ( default = llama_cpp . LLAMA_DEFAULT_SEED , description = " Random seed. -1 for random. " )
2023-05-07 06:52:20 +00:00
n_ctx : int = Field ( default = 2048 , ge = 1 , description = " The context size. " )
2023-09-14 01:23:13 +00:00
n_batch : int = Field (
default = 512 , ge = 1 , description = " The batch size to use per eval. "
)
2023-05-14 04:04:22 +00:00
n_gpu_layers : int = Field (
default = 0 ,
ge = 0 ,
description = " The number of layers to put on the GPU. The rest will be on the CPU. " ,
)
2023-09-14 01:23:13 +00:00
main_gpu : int = Field (
default = 0 ,
ge = 0 ,
description = " Main GPU to use. " ,
)
2023-07-14 20:52:48 +00:00
tensor_split : Optional [ List [ float ] ] = Field (
2023-07-07 09:22:10 +00:00
default = None ,
description = " Split layers across multiple GPUs in proportion. " ,
)
2023-07-19 07:48:27 +00:00
rope_freq_base : float = Field (
default = 10000 , ge = 1 , description = " RoPE base frequency "
2023-06-22 20:19:24 +00:00
)
2023-07-19 07:48:27 +00:00
rope_freq_scale : float = Field (
default = 1.0 , description = " RoPE frequency scaling factor "
)
2023-09-14 01:23:13 +00:00
low_vram : bool = Field (
default = False ,
description = " Whether to use less VRAM. This will reduce performance. " ,
2023-05-07 06:52:20 +00:00
)
2023-09-14 01:23:13 +00:00
mul_mat_q : bool = Field (
default = True , description = " if true, use experimental mul_mat_q kernels "
2023-05-07 06:52:20 +00:00
)
f16_kv : bool = Field ( default = True , description = " Whether to use f16 key/value. " )
2023-09-14 01:23:13 +00:00
logits_all : bool = Field ( default = True , description = " Whether to return logits. " )
vocab_only : bool = Field (
default = False , description = " Whether to only return the vocabulary. "
2023-05-07 06:52:20 +00:00
)
# The mmap / mlock defaults are computed once, at class-definition time, by asking llama_cpp.
use_mmap : bool = Field (
2023-05-07 07:04:22 +00:00
default = llama_cpp . llama_mmap_supported ( ) ,
2023-05-07 06:52:20 +00:00
description = " Use mmap. " ,
)
2023-09-14 01:23:13 +00:00
use_mlock : bool = Field (
default = llama_cpp . llama_mlock_supported ( ) ,
description = " Use mlock. " ,
)
2023-05-07 06:52:20 +00:00
embedding : bool = Field ( default = True , description = " Whether to use embeddings. " )
2023-09-14 01:23:13 +00:00
# Default thread count: half of the reported CPU count, but never less than 1.
n_threads : int = Field (
default = max ( multiprocessing . cpu_count ( ) / / 2 , 1 ) ,
ge = 1 ,
description = " The number of threads to use. " ,
2023-06-15 02:13:42 +00:00
)
2023-05-07 06:52:20 +00:00
last_n_tokens_size : int = Field (
default = 64 ,
ge = 0 ,
description = " Last n tokens to keep for repeat penalty calculation. " ,
)
2023-09-14 01:23:13 +00:00
lora_base : Optional [ str ] = Field (
default = None ,
description = " Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. "
)
lora_path : Optional [ str ] = Field (
default = None ,
description = " Path to a LoRA file to apply to the model. " ,
)
2023-09-14 03:00:43 +00:00
numa : bool = Field (
default = False ,
description = " Enable NUMA support. " ,
)
2023-05-07 06:52:20 +00:00
cache : bool = Field (
default = False ,
description = " Use a cache to reduce processing times for evaluated prompts. " ,
)
2023-06-08 17:19:23 +00:00
cache_type : Literal [ " ram " , " disk " ] = Field (
default = " ram " ,
description = " The type of cache to use. Only used if cache is True. " ,
)
2023-05-07 23:33:17 +00:00
# 2 << 30 bytes == 2 GiB.
cache_size : int = Field (
default = 2 << 30 ,
description = " The size of the cache in bytes. Only used if cache is True. " ,
)
2023-05-07 09:09:10 +00:00
verbose : bool = Field (
default = True , description = " Whether to print debug information. "
)
2023-07-14 03:25:12 +00:00
host : str = Field ( default = " localhost " , description = " Listen address " )
port : int = Field ( default = 8000 , description = " Listen port " )
2023-07-07 07:37:23 +00:00
interrupt_requests : bool = Field (
default = True ,
description = " Whether to interrupt requests when a new request is received. " ,
)
2023-04-29 05:43:37 +00:00
2023-07-16 05:57:39 +00:00
class ErrorResponse(TypedDict):
    """OpenAI style error response"""

    # Human-readable explanation of what went wrong.
    message: str
    # Error category string, e.g. "invalid_request_error" or "internal_server_error".
    type: str
    # Name of the offending request parameter, when one can be identified.
    param: Optional[str]
    # Machine-readable error code, when one is available.
    code: Optional[str]
# Namespace of static formatter functions. Each one takes the parsed request body and
# the regex match produced from the original error message, and returns a
# (HTTP status code, ErrorResponse) pair. RouteErrorHandler maps regex patterns to
# these formatters.
class ErrorResponseFormatters :
""" Collection of formatters for error responses.
Args :
request ( Union [ CreateCompletionRequest , CreateChatCompletionRequest ] ) :
Request body
match ( Match [ str ] ) : Match object from regex pattern
Returns :
2023-07-20 22:52:10 +00:00
Tuple [ int , ErrorResponse ] : Status code and error response
2023-07-16 05:57:39 +00:00
"""
@staticmethod
def context_length_exceeded (
2023-09-14 01:23:23 +00:00
request : Union [ " CreateCompletionRequest " , " CreateChatCompletionRequest " ] ,
match , # type: Match[str] # type: ignore
2023-07-20 22:52:10 +00:00
) - > Tuple [ int , ErrorResponse ] :
2023-07-16 05:57:39 +00:00
""" Formatter for context length exceeded error """
# Capture group 2 is read as the context window size, group 1 as the prompt token count.
context_window = int ( match . group ( 2 ) )
prompt_tokens = int ( match . group ( 1 ) )
completion_tokens = request . max_tokens
# A request object that has a `messages` attribute is treated as a chat completion;
# anything else is treated as a text completion. Only the wording differs.
if hasattr ( request , " messages " ) :
# Chat completion
message = (
" This model ' s maximum context length is {} tokens. "
" However, you requested {} tokens "
" ( {} in the messages, {} in the completion). "
" Please reduce the length of the messages or completion. "
)
else :
# Text completion
message = (
" This model ' s maximum context length is {} tokens, "
" however you requested {} tokens "
" ( {} in your prompt; {} for the completion). "
" Please reduce your prompt; or completion length. "
)
# Placeholders are filled in order: window, total requested, prompt part, completion part.
# NOTE(review): `param` is "messages" for text completions as well — confirm this is intended.
return 400 , ErrorResponse (
message = message . format (
context_window ,
completion_tokens + prompt_tokens ,
prompt_tokens ,
completion_tokens ,
) ,
type = " invalid_request_error " ,
param = " messages " ,
code = " context_length_exceeded " ,
)
@staticmethod
def model_not_found (
2023-09-14 01:23:23 +00:00
request : Union [ " CreateCompletionRequest " , " CreateChatCompletionRequest " ] ,
match , # type: Match[str] # type: ignore
2023-07-20 22:52:10 +00:00
) - > Tuple [ int , ErrorResponse ] :
2023-07-16 05:57:39 +00:00
""" Formatter for model_not_found error """
# Capture group 1 holds the model path taken from the original error text;
# `request` is not used by this formatter.
model_path = str ( match . group ( 1 ) )
message = f " The model ` { model_path } ` does not exist "
return 400 , ErrorResponse (
message = message ,
type = " invalid_request_error " ,
param = None ,
code = " model_not_found " ,
)
class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions.

    Any exception raised by the wrapped route handler is converted into a
    JSON response of the form ``{"error": ErrorResponse}``. Known error
    messages are matched by regex and formatted as 400 responses; everything
    else becomes a 500 ``internal_server_error``.
    """

    # key: regex pattern for original error message from llama_cpp
    # value: formatter function
    pattern_and_formatters: Dict[
        "Pattern",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response.

        Args:
            error: The exception raised by the route handler.
            body: The parsed request body, or None when it could not be parsed.

        Returns:
            Status code and error response.
        """
        # Pattern-based formatting only applies to text / chat completion
        # requests; isinstance() is already False for None.
        if isinstance(body, (CreateCompletionRequest, CreateChatCompletionRequest)):
            error_text = str(error)
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(error_text)
                if match is not None:
                    return callback(body, match)

        # Wrap other errors as internal server error
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        in OpenAI style error response"""
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                return await original_route_handler(request)
            except Exception as exc:
                body: Optional[
                    Union[
                        CreateChatCompletionRequest,
                        CreateCompletionRequest,
                        CreateEmbeddingRequest,
                    ]
                ] = None
                try:
                    # Reading the body is part of the guarded section: a
                    # request with no body or with invalid JSON must not
                    # raise a second exception from inside this handler.
                    json_body = await request.json()
                    if "messages" in json_body:
                        # Chat completion
                        body = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Missing, malformed or invalid request body
                    body = None

                # Get proper error message from the exception
                status_code, error_message = self.error_message_wrapper(
                    error=exc, body=body
                )
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler
# Router on which all endpoints below are registered; every route uses
# RouteErrorHandler so exceptions are returned as OpenAI-style error JSON.
router = APIRouter ( route_class = RouteErrorHandler )
2023-05-02 02:38:46 +00:00
2023-05-16 21:22:00 +00:00
# Module-level singletons; both start as None and are assigned by create_app.
settings : Optional [ Settings ] = None
2023-05-02 02:38:46 +00:00
llama : Optional [ llama_cpp . Llama ] = None
2023-04-29 05:43:37 +00:00
2023-05-02 02:38:46 +00:00
def create_app(settings: Optional[Settings] = None):
    """Build the FastAPI application and load the llama model.

    Args:
        settings: Server configuration. When None, a default ``Settings()``
            instance is created.

    Returns:
        The configured FastAPI app with middleware and the API router attached.

    Side effects:
        Assigns the module-level ``llama`` and ``settings`` globals that the
        request dependencies read.
    """
    if settings is None:
        settings = Settings()

    middleware = [
        Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
    ]
    app = FastAPI(
        middleware=middleware,
        title="🦙 llama.cpp Python API",
        version="0.0.1",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    global llama
    llama = llama_cpp.Llama(
        model_path=settings.model,
        seed=settings.seed,
        n_ctx=settings.n_ctx,
        n_batch=settings.n_batch,
        n_gpu_layers=settings.n_gpu_layers,
        main_gpu=settings.main_gpu,
        tensor_split=settings.tensor_split,
        rope_freq_base=settings.rope_freq_base,
        rope_freq_scale=settings.rope_freq_scale,
        low_vram=settings.low_vram,
        mul_mat_q=settings.mul_mat_q,
        f16_kv=settings.f16_kv,
        logits_all=settings.logits_all,
        vocab_only=settings.vocab_only,
        use_mmap=settings.use_mmap,
        use_mlock=settings.use_mlock,
        embedding=settings.embedding,
        n_threads=settings.n_threads,
        last_n_tokens_size=settings.last_n_tokens_size,
        lora_base=settings.lora_base,
        lora_path=settings.lora_path,
        verbose=settings.verbose,
    )

    if settings.cache:
        if settings.cache_type == "disk":
            if settings.verbose:
                print(f"Using disk cache with size {settings.cache_size}")
            cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
        else:
            if settings.verbose:
                print(f"Using ram cache with size {settings.cache_size}")
            cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
        # The cache selected above is installed as-is; it must not be replaced
        # by a generic LlamaCache, otherwise cache_type has no effect.
        llama.set_cache(cache)

    # `settings` is a parameter of create_app, so it cannot be declared
    # `global` here; the nested helper rebinds the module-level name instead.
    def set_settings(_settings: Settings):
        global settings
        settings = _settings

    set_settings(settings)
    return app
2023-04-29 05:43:37 +00:00
2023-07-07 07:04:17 +00:00
# Two module-level locks used together by get_llama: the outer lock is held only while
# a caller waits for the inner lock, and the inner lock is held while the model is in use.
# get_event_publisher checks llama_outer_lock.locked() to detect a waiting request.
llama_outer_lock = Lock ( )
llama_inner_lock = Lock ( )
2023-05-02 02:38:46 +00:00
2023-04-29 05:43:37 +00:00
def get_llama():
    """Yield the shared llama model while holding the inner lock.

    The outer lock is held only for as long as it takes to obtain the inner
    lock, and is released before the model is handed out. The inner lock is
    released when the generator is resumed or closed.
    """
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    with llama_outer_lock:
        # The outer lock is dropped as soon as the inner one is obtained
        # (or if obtaining it fails).
        llama_inner_lock.acquire()
    try:
        yield llama
    finally:
        llama_inner_lock.release()
2023-04-29 05:43:37 +00:00
2023-05-07 06:52:20 +00:00
2023-05-16 21:22:00 +00:00
def get_settings():
    """Yield the module-level ``settings`` object (used as a dependency)."""
    current_settings = settings
    yield current_settings
2023-07-16 05:57:39 +00:00
# Pumps chunks from a blocking iterator into the send side of a memory object stream,
# one JSON-encoded event per chunk, followed by a final "[DONE]" event.
# Streaming stops early (by raising the backend's cancellation exception) when the
# client has disconnected, or when settings.interrupt_requests is set and
# llama_outer_lock is currently held.
# The cancellation exception is always re-raised after logging.
async def get_event_publisher (
request : Request ,
inner_send_chan : MemoryObjectSendStream ,
iterator : Iterator ,
) :
# The send channel is closed when this block exits, on any path.
async with inner_send_chan :
try :
# The synchronous iterator is advanced in a worker thread.
async for chunk in iterate_in_threadpool ( iterator ) :
await inner_send_chan . send ( dict ( data = json . dumps ( chunk ) ) )
if await request . is_disconnected ( ) :
raise anyio . get_cancelled_exc_class ( ) ( )
# On interruption the "[DONE]" event is sent before cancelling.
if settings . interrupt_requests and llama_outer_lock . locked ( ) :
await inner_send_chan . send ( dict ( data = " [DONE] " ) )
raise anyio . get_cancelled_exc_class ( ) ( )
await inner_send_chan . send ( dict ( data = " [DONE] " ) )
except anyio . get_cancelled_exc_class ( ) as e :
print ( " disconnected " )
# Shielded scope with a 1 second deadline around the final log line.
with anyio . move_on_after ( 1 , shield = True ) :
2023-09-14 01:23:23 +00:00
print ( f " Disconnected from client (via refresh/close) { request . client } " )
2023-07-16 05:57:39 +00:00
raise e
2023-09-14 01:23:23 +00:00
# Shared pydantic Field definitions. Each object below is reused as the default of the
# same-named attribute on the request models in this file, so the defaults, bounds and
# descriptions are declared once.
2023-09-14 01:23:23 +00:00
model_field = Field (
description = " The model to use for generating completions. " , default = None
)
2023-04-29 07:47:35 +00:00
2023-04-30 01:37:43 +00:00
max_tokens_field = Field (
2023-07-16 05:57:39 +00:00
default = 16 , ge = 1 , description = " The maximum number of tokens to generate. "
2023-04-30 01:37:43 +00:00
)
temperature_field = Field (
default = 0.8 ,
ge = 0.0 ,
le = 2.0 ,
2023-05-07 06:52:20 +00:00
description = " Adjust the randomness of the generated text. \n \n "
+ " Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model ' s output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run. " ,
2023-04-30 01:37:43 +00:00
)
top_p_field = Field (
default = 0.95 ,
ge = 0.0 ,
le = 1.0 ,
2023-05-07 06:52:20 +00:00
description = " Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. \n \n "
+ " Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. " ,
2023-04-30 01:37:43 +00:00
)
stop_field = Field (
default = None ,
2023-05-07 06:52:20 +00:00
description = " A list of tokens at which to stop generation. If None, no stop tokens are used. " ,
2023-04-30 01:37:43 +00:00
)
stream_field = Field (
default = False ,
2023-05-07 06:52:20 +00:00
description = " Whether to stream the results as they are generated. Useful for chatbots. " ,
2023-04-30 01:37:43 +00:00
)
top_k_field = Field (
default = 40 ,
ge = 0 ,
2023-05-07 06:52:20 +00:00
description = " Limit the next token selection to the K most probable tokens. \n \n "
+ " Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text. " ,
2023-04-30 01:37:43 +00:00
)
repeat_penalty_field = Field (
2023-05-08 22:49:11 +00:00
default = 1.1 ,
2023-04-30 01:37:43 +00:00
ge = 0.0 ,
2023-05-07 06:52:20 +00:00
description = " A penalty applied to each token that is already generated. This helps prevent the model from repeating itself. \n \n "
+ " Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. " ,
2023-04-30 01:37:43 +00:00
)
2023-05-09 23:19:46 +00:00
# The two penalties below share the same bounds: -2.0 to 2.0 inclusive, default 0.0.
presence_penalty_field = Field (
default = 0.0 ,
ge = - 2.0 ,
le = 2.0 ,
description = " Positive values penalize new tokens based on whether they appear in the text so far, increasing the model ' s likelihood to talk about new topics. " ,
)
frequency_penalty_field = Field (
default = 0.0 ,
ge = - 2.0 ,
le = 2.0 ,
description = " Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model ' s likelihood to repeat the same line verbatim. " ,
)
2023-05-07 06:52:20 +00:00
2023-06-06 02:37:11 +00:00
# Mirostat sampling controls; mode is restricted to 0, 1 or 2.
mirostat_mode_field = Field (
default = 0 ,
ge = 0 ,
le = 2 ,
2023-07-14 03:25:12 +00:00
description = " Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled) " ,
2023-06-06 02:37:11 +00:00
)
mirostat_tau_field = Field (
default = 5.0 ,
ge = 0.0 ,
le = 10.0 ,
2023-07-14 03:25:12 +00:00
description = " Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text " ,
2023-06-06 02:37:11 +00:00
)
mirostat_eta_field = Field (
2023-07-14 03:25:12 +00:00
default = 0.1 , ge = 0.001 , le = 1.0 , description = " Mirostat learning rate "
2023-06-06 02:37:11 +00:00
)
2023-05-12 11:21:46 +00:00
2023-04-29 05:43:37 +00:00
# Request body for the text completion endpoints.
# Most sampling options reuse the shared module-level Field objects so their
# defaults, bounds and descriptions stay in one place. `logprobs` is declared
# exactly once so that its `ge=0` bound and description are not overridden.
class CreateCompletionRequest(BaseModel):
    prompt: Union[str, List[str]] = Field(
        default="", description="The prompt to generate completions for."
    )
    suffix: Optional[str] = Field(
        default=None,
        description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
    )
    max_tokens: int = max_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    echo: bool = Field(
        default=False,
        description="Whether to echo the prompt in the generated text. Useful for chatbots.",
    )
    stop: Optional[Union[str, List[str]]] = stop_field
    stream: bool = stream_field
    logprobs: Optional[int] = Field(
        default=None,
        ge=0,
        description="The number of logprobs to generate. If None, no logprobs are generated.",
    )
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(None)

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    best_of: Optional[int] = 1
    user: Optional[str] = Field(default=None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                    "stop": ["\n", "###"],
                }
            ]
        }
    }
2023-04-29 05:43:37 +00:00
2023-06-09 17:13:08 +00:00
def make_logit_bias_processor(
    llama: "llama_cpp.Llama",
    logit_bias: Dict[str, float],
    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
    """Build a logits processor that adds a fixed bias to selected token ids.

    Args:
        llama: Model used to tokenize keys when ``logit_bias_type`` is
            ``"tokens"``; not touched otherwise.
        logit_bias: Mapping from key to bias. Keys are stringified token ids
            for ``"input_ids"`` and token text for ``"tokens"``.
        logit_bias_type: How to interpret the keys; None means ``"input_ids"``.

    Returns:
        A callable ``(input_ids, scores) -> new_scores`` that returns a new
        array and leaves ``scores`` unmodified.

    Raises:
        ValueError: If an ``"input_ids"`` key is not an integer string.
    """
    if logit_bias_type is None:
        logit_bias_type = "input_ids"

    # Resolve every key to concrete token ids once, up front.
    to_bias: Dict[int, float] = {}
    if logit_bias_type == "input_ids":
        for input_id, score in logit_bias.items():
            to_bias[int(input_id)] = score
    elif logit_bias_type == "tokens":
        for token, score in logit_bias.items():
            # A single text key may tokenize to several ids; each gets the bias.
            for input_id in llama.tokenize(token.encode("utf-8"), add_bos=False):
                to_bias[input_id] = score

    def logit_bias_processor(
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
    ) -> npt.NDArray[np.single]:
        # np.array copies, so the caller's scores are never mutated, and the
        # result is an ndarray as the annotation promises.
        new_scores = np.array(scores, dtype=np.single)
        vocab_size = len(new_scores)
        # Touch only the biased ids instead of walking the whole vocabulary.
        for input_id, bias in to_bias.items():
            # Ids outside the score vector are ignored.
            if 0 <= input_id < vocab_size:
                new_scores[input_id] += bias
        return new_scores

    return logit_bias_processor
2023-05-02 02:38:46 +00:00
# Text completion endpoint, registered under two paths.
# Runs the model in a worker thread with the request fields as keyword arguments and
# returns either the finished completion or, when the model call yields an iterator,
# a server-sent-event stream fed by get_event_publisher.
@router.post (
2023-04-29 05:43:37 +00:00
" /v1/completions " ,
)
2023-08-25 21:49:14 +00:00
@router.post ( " /v1/engines/copilot-codex/completions " )
2023-05-27 13:12:58 +00:00
async def create_completion (
request : Request ,
body : CreateCompletionRequest ,
llama : llama_cpp . Llama = Depends ( get_llama ) ,
2023-07-14 03:25:12 +00:00
) - > llama_cpp . Completion :
2023-05-27 13:12:58 +00:00
# A list prompt is accepted only with zero or one element and is unwrapped to a string.
# NOTE(review): this check is an `assert`, so it disappears under `python -O`.
if isinstance ( body . prompt , list ) :
assert len ( body . prompt ) < = 1
body . prompt = body . prompt [ 0 ] if len ( body . prompt ) > 0 else " "
# Request fields that are not forwarded to the model call.
exclude = {
" n " ,
" best_of " ,
" logit_bias " ,
2023-06-09 17:13:08 +00:00
" logit_bias_type " ,
2023-05-27 13:12:58 +00:00
" user " ,
}
2023-07-14 03:25:12 +00:00
kwargs = body . model_dump ( exclude = exclude )
2023-06-09 17:13:08 +00:00
# logit_bias is passed on as a logits processor rather than as a plain argument.
if body . logit_bias is not None :
2023-07-19 07:48:27 +00:00
kwargs [ " logits_processor " ] = llama_cpp . LogitsProcessorList (
[
make_logit_bias_processor ( llama , body . logit_bias , body . logit_bias_type ) ,
]
)
2023-06-09 17:13:08 +00:00
2023-09-14 01:23:23 +00:00
iterator_or_completion : Union [
llama_cpp . Completion , Iterator [ llama_cpp . CompletionChunk ]
] = await run_in_threadpool ( llama , * * kwargs )
2023-05-27 13:12:58 +00:00
2023-07-16 05:57:39 +00:00
if isinstance ( iterator_or_completion , Iterator ) :
# EAFP: It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool ( next , iterator_or_completion )
2023-05-19 06:04:30 +00:00
2023-07-16 05:57:39 +00:00
# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
# Re-attach the already-consumed first chunk in front of the remaining ones.
def iterator ( ) - > Iterator [ llama_cpp . CompletionChunk ] :
yield first_response
yield from iterator_or_completion
# Memory object stream created with a buffer size of 10.
send_chan , recv_chan = anyio . create_memory_object_stream ( 10 )
2023-05-27 13:12:58 +00:00
return EventSourceResponse (
2023-09-14 01:23:23 +00:00
recv_chan ,
data_sender_callable = partial ( # type: ignore
2023-07-16 05:57:39 +00:00
get_event_publisher ,
request = request ,
inner_send_chan = send_chan ,
iterator = iterator ( ) ,
2023-09-14 01:23:23 +00:00
) ,
2023-07-16 05:57:39 +00:00
)
2023-05-27 13:12:58 +00:00
else :
2023-07-16 05:57:39 +00:00
return iterator_or_completion
2023-04-29 05:43:37 +00:00
# Request body for the embeddings endpoint. `input` is the only field without a default.
class CreateEmbeddingRequest ( BaseModel ) :
2023-05-07 06:00:22 +00:00
model : Optional [ str ] = model_field
2023-05-19 23:23:32 +00:00
# Either a single string or a list of strings.
input : Union [ str , List [ str ] ] = Field ( description = " The input to embed. " )
2023-07-14 03:25:12 +00:00
# Dropped by create_embedding before the model call.
user : Optional [ str ] = Field ( default = None )
# Example payload attached to the generated JSON schema.
model_config = {
" json_schema_extra " : {
" examples " : [
{
" input " : " The food was delicious and the waiter... " ,
}
]
2023-04-29 05:43:37 +00:00
}
2023-07-14 03:25:12 +00:00
}
2023-04-29 05:43:37 +00:00
2023-05-02 02:38:46 +00:00
# Embeddings endpoint: calls llama.create_embedding in a worker thread with every
# request field except `user`, and returns its result unchanged.
@router.post (
2023-04-29 05:43:37 +00:00
" /v1/embeddings " ,
)
2023-05-27 13:12:58 +00:00
async def create_embedding (
2023-04-29 05:43:37 +00:00
request : CreateEmbeddingRequest , llama : llama_cpp . Llama = Depends ( get_llama )
) :
2023-05-27 13:12:58 +00:00
return await run_in_threadpool (
2023-07-14 03:25:12 +00:00
llama . create_embedding , * * request . model_dump ( exclude = { " user " } )
2023-05-27 13:12:58 +00:00
)
2023-04-29 05:43:37 +00:00
# One chat message: a role restricted to three literal values (default "user")
# and a text content that defaults to the empty string.
class ChatCompletionRequestMessage ( BaseModel ) :
2023-05-01 18:48:37 +00:00
role : Literal [ " system " , " user " , " assistant " ] = Field (
default = " user " , description = " The role of the message. "
2023-04-30 01:37:43 +00:00
)
content : str = Field ( default = " " , description = " The content of the message. " )
2023-04-29 05:43:37 +00:00
# Request body for the chat completion endpoint. Sampling options reuse the shared
# module-level Field objects, as in CreateCompletionRequest.
class CreateChatCompletionRequest ( BaseModel ) :
2023-04-30 01:37:43 +00:00
messages : List [ ChatCompletionRequestMessage ] = Field (
2023-05-07 06:52:20 +00:00
default = [ ] , description = " A list of messages to generate completions for. "
2023-04-30 01:37:43 +00:00
)
2023-07-19 07:48:20 +00:00
functions : Optional [ List [ llama_cpp . ChatCompletionFunction ] ] = Field (
default = None ,
description = " A list of functions to apply to the generated completions. " ,
)
function_call : Optional [ Union [ str , llama_cpp . ChatCompletionFunctionCall ] ] = Field (
default = None ,
description = " A function to apply to the generated completions. " ,
)
2023-04-30 01:37:43 +00:00
max_tokens : int = max_tokens_field
temperature : float = temperature_field
top_p : float = top_p_field
2023-06-06 02:37:11 +00:00
mirostat_mode : int = mirostat_mode_field
mirostat_tau : float = mirostat_tau_field
mirostat_eta : float = mirostat_eta_field
2023-04-30 01:37:43 +00:00
# NOTE(review): only a list is accepted here, while CreateCompletionRequest.stop also
# accepts a single string — confirm whether the difference is intended.
stop : Optional [ List [ str ] ] = stop_field
stream : bool = stream_field
2023-05-09 23:19:46 +00:00
presence_penalty : Optional [ float ] = presence_penalty_field
frequency_penalty : Optional [ float ] = frequency_penalty_field
2023-06-09 17:13:08 +00:00
logit_bias : Optional [ Dict [ str , float ] ] = Field ( None )
2023-04-29 05:43:37 +00:00
2023-05-07 06:00:22 +00:00
# ignored or currently unsupported
model : Optional [ str ] = model_field
n : Optional [ int ] = 1
user : Optional [ str ] = Field ( None )
2023-04-29 05:43:37 +00:00
# llama.cpp specific parameters
2023-04-30 01:37:43 +00:00
top_k : int = top_k_field
repeat_penalty : float = repeat_penalty_field
2023-06-15 02:08:28 +00:00
logit_bias_type : Optional [ Literal [ " input_ids " , " tokens " ] ] = Field ( None )
2023-04-29 05:43:37 +00:00
2023-07-14 03:25:12 +00:00
# Example payload attached to the generated JSON schema; the two sample messages are
# built as ChatCompletionRequestMessage objects and dumped to plain dicts.
model_config = {
" json_schema_extra " : {
" examples " : [
{
" messages " : [
ChatCompletionRequestMessage (
role = " system " , content = " You are a helpful assistant. "
) . model_dump ( ) ,
ChatCompletionRequestMessage (
role = " user " , content = " What is the capital of France? "
) . model_dump ( ) ,
]
}
]
2023-04-29 05:43:37 +00:00
}
2023-07-14 03:25:12 +00:00
}
2023-04-29 05:43:37 +00:00
2023-05-02 02:38:46 +00:00
# Chat completion endpoint.
# Runs llama.create_chat_completion in a worker thread with the request fields as
# keyword arguments and returns either the finished completion or, when the call
# yields an iterator, a server-sent-event stream fed by get_event_publisher.
# The injected `settings` parameter is not referenced inside this function body.
@router.post (
2023-04-29 05:43:37 +00:00
" /v1/chat/completions " ,
)
2023-05-27 13:12:58 +00:00
async def create_chat_completion (
request : Request ,
body : CreateChatCompletionRequest ,
2023-04-29 05:43:37 +00:00
llama : llama_cpp . Llama = Depends ( get_llama ) ,
2023-07-07 07:37:23 +00:00
settings : Settings = Depends ( get_settings ) ,
2023-07-14 03:25:12 +00:00
) - > llama_cpp . ChatCompletion :
2023-05-27 13:12:58 +00:00
# Request fields that are not forwarded to the model call.
exclude = {
" n " ,
" logit_bias " ,
2023-06-09 17:13:08 +00:00
" logit_bias_type " ,
2023-05-27 13:12:58 +00:00
" user " ,
}
2023-07-14 03:25:12 +00:00
kwargs = body . model_dump ( exclude = exclude )
2023-06-09 17:13:08 +00:00
# logit_bias is passed on as a logits processor rather than as a plain argument.
if body . logit_bias is not None :
2023-07-19 07:48:27 +00:00
kwargs [ " logits_processor " ] = llama_cpp . LogitsProcessorList (
[
make_logit_bias_processor ( llama , body . logit_bias , body . logit_bias_type ) ,
]
)
2023-06-09 17:13:08 +00:00
2023-09-14 01:23:23 +00:00
iterator_or_completion : Union [
llama_cpp . ChatCompletion , Iterator [ llama_cpp . ChatCompletionChunk ]
] = await run_in_threadpool ( llama . create_chat_completion , * * kwargs )
2023-05-27 13:12:58 +00:00
2023-07-16 05:57:39 +00:00
if isinstance ( iterator_or_completion , Iterator ) :
# EAFP: It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool ( next , iterator_or_completion )
# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
# Re-attach the already-consumed first chunk in front of the remaining ones.
def iterator ( ) - > Iterator [ llama_cpp . ChatCompletionChunk ] :
yield first_response
yield from iterator_or_completion
2023-04-29 05:43:37 +00:00
2023-07-16 05:57:39 +00:00
# Memory object stream created with a buffer size of 10.
send_chan , recv_chan = anyio . create_memory_object_stream ( 10 )
2023-04-29 05:43:37 +00:00
return EventSourceResponse (
2023-09-14 01:23:23 +00:00
recv_chan ,
data_sender_callable = partial ( # type: ignore
2023-07-16 05:57:39 +00:00
get_event_publisher ,
request = request ,
inner_send_chan = send_chan ,
iterator = iterator ( ) ,
2023-09-14 01:23:23 +00:00
) ,
2023-04-29 05:43:37 +00:00
)
2023-07-16 05:57:39 +00:00
else :
return iterator_or_completion
2023-04-29 05:43:37 +00:00
# Shape of one model entry inside ModelList.data.
class ModelData(TypedDict):
    # Identifier reported for the model.
    id: str
    # Always the literal string "model".
    object: Literal["model"]
    owned_by: str
    permissions: List[str]
# Shape of the model listing: a fixed "list" marker plus the ModelData entries.
class ModelList(TypedDict):
    # Always the literal string "list".
    object: Literal["list"]
    data: List[ModelData]
2023-07-08 01:38:46 +00:00
@router.get("/v1/models")
async def get_models(
    settings: Settings = Depends(get_settings),
) -> ModelList:
    """Return a one-element model listing for the loaded model.

    The entry's id is ``settings.model_alias`` when an alias is configured,
    otherwise the loaded model's ``model_path``.
    """
    assert llama is not None

    model_id = settings.model_alias
    if model_id is None:
        # No alias configured: fall back to the path of the loaded model.
        model_id = llama.model_path

    model_entry = {
        "id": model_id,
        "object": "model",
        "owned_by": "me",
        "permissions": [],
    }
    return {"object": "list", "data": [model_entry]}