2023-04-28 22:43:37 -07:00
import json
2023-05-07 03:03:57 -04:00
import multiprocessing
2023-04-28 22:43:37 -07:00
from threading import Lock
2023-05-01 15:11:15 -04:00
from typing import List , Optional , Union , Iterator , Dict
2023-05-07 03:03:57 -04:00
from typing_extensions import TypedDict , Literal
2023-04-28 22:43:37 -07:00
import llama_cpp
2023-05-01 22:38:46 -04:00
from fastapi import APIRouter, Depends, FastAPI, HTTPException
2023-04-28 22:43:37 -07:00
from fastapi . middleware . cors import CORSMiddleware
from pydantic import BaseModel , BaseSettings , Field , create_model_from_typeddict
from sse_starlette . sse import EventSourceResponse
class Settings(BaseSettings):
    # Server configuration model. create_app() reads these values to build
    # the llama_cpp.Llama instance and, optionally, its prompt cache.

    # --- model selection -------------------------------------------------
    model: str = Field(
        description=" The path to the model to use for generating completions. "
    )
    model_alias: Optional[str] = Field(
        default=None,
        description=" The alias of the model to use for generating completions. ",
    )

    # --- evaluation parameters -------------------------------------------
    n_ctx: int = Field(
        default=2048,
        ge=1,
        description=" The context size. ",
    )
    n_gpu_layers: int = Field(
        default=0,
        ge=0,
        description=" The number of layers to put on the GPU. The rest will be on the CPU. ",
    )
    n_batch: int = Field(
        default=512,
        ge=1,
        description=" The batch size to use per eval. ",
    )
    # Default is half of the reported CPU count, but never less than one.
    # It is computed once, when this class body is executed.
    n_threads: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=1,
        description=" The number of threads to use. ",
    )
    f16_kv: bool = Field(
        default=True,
        description=" Whether to use f16 key/value. ",
    )

    # --- memory handling (defaults come from llama_cpp capability probes) -
    use_mlock: bool = Field(
        default=llama_cpp.llama_mlock_supported(),
        description=" Use mlock. ",
    )
    use_mmap: bool = Field(
        default=llama_cpp.llama_mmap_supported(),
        description=" Use mmap. ",
    )

    # --- model behaviour ---------------------------------------------------
    embedding: bool = Field(
        default=True,
        description=" Whether to use embeddings. ",
    )
    last_n_tokens_size: int = Field(
        default=64,
        ge=0,
        description=" Last n tokens to keep for repeat penalty calculation. ",
    )
    logits_all: bool = Field(
        default=True,
        description=" Whether to return logits. ",
    )

    # --- prompt cache ------------------------------------------------------
    cache: bool = Field(
        default=False,
        description=" Use a cache to reduce processing times for evaluated prompts. ",
    )
    # 2 << 30 == 2 * 2**30 bytes.
    cache_size: int = Field(
        default=2 << 30,
        description=" The size of the cache in bytes. Only used if cache is True. ",
    )

    # --- miscellaneous -----------------------------------------------------
    vocab_only: bool = Field(
        default=False,
        description=" Whether to only return the vocabulary. ",
    )
    verbose: bool = Field(
        default=True,
        description=" Whether to print debug information. ",
    )
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
# All endpoints in this module are registered on this router; create_app()
# includes it in the FastAPI application.
router = APIRouter()

# Module-level state. Both start as None and are rebound by create_app();
# get_settings() and get_llama() hand them to the route handlers.
settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
def create_app(settings: Optional[Settings] = None):
    """Build the FastAPI application and load the llama.cpp model.

    Args:
        settings: Server configuration. When ``None``, ``Settings()`` is
            constructed with no arguments.

    Returns:
        The ``FastAPI`` app, with the CORS middleware added and the
        module-level ``router`` included.

    Side effects:
        Rebinds the module-level ``llama`` and ``settings`` globals.
    """
    global llama

    def _publish_settings(active: Settings):
        # The ``settings`` parameter of create_app shadows the module-level
        # name, and a name cannot be both a parameter and ``global`` in the
        # same function, so the rebinding is done from this nested scope.
        global settings
        settings = active

    if settings is None:
        settings = Settings()

    app = FastAPI(
        title=" 🦙 llama.cpp Python API ",
        version=" 0.0.1 ",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[" * "],
        allow_credentials=True,
        allow_methods=[" * "],
        allow_headers=[" * "],
    )
    app.include_router(router)

    llama = llama_cpp.Llama(
        model_path=settings.model,
        n_gpu_layers=settings.n_gpu_layers,
        f16_kv=settings.f16_kv,
        use_mlock=settings.use_mlock,
        use_mmap=settings.use_mmap,
        embedding=settings.embedding,
        logits_all=settings.logits_all,
        n_threads=settings.n_threads,
        n_batch=settings.n_batch,
        n_ctx=settings.n_ctx,
        last_n_tokens_size=settings.last_n_tokens_size,
        vocab_only=settings.vocab_only,
        verbose=settings.verbose,
    )

    # The prompt cache is opt-in; its capacity comes from ``cache_size``.
    if settings.cache:
        llama.set_cache(llama_cpp.LlamaCache(capacity_bytes=settings.cache_size))

    _publish_settings(settings)
    return app
2023-04-28 22:43:37 -07:00
2023-04-28 23:47:36 -07:00
# Guards the shared model instance: get_llama() holds this lock for as long
# as its generator is suspended at the ``yield``.
llama_lock = Lock()


def get_llama():
    """Yield the module-level ``llama`` while holding ``llama_lock``.

    Written as a generator so it can be used with ``Depends``; the lock is
    released when the generator is resumed past the ``yield`` or closed.
    """
    with llama_lock:
        yield llama
2023-05-07 02:52:20 -04:00
2023-05-16 17:22:00 -04:00
def get_settings():
    """Yield the module-level ``settings`` object (used with ``Depends``)."""
    yield settings
2023-05-07 02:52:20 -04:00
# Shared pydantic Field definitions. They are declared once here and reused
# by the request models so that the completion and chat-completion schemas
# carry the same defaults, bounds and descriptions.

model_field = Field(
    description=" The model to use for generating completions. ",
)

max_tokens_field = Field(
    default=16,
    ge=1,
    le=2048,
    description=" The maximum number of tokens to generate. ",
)

temperature_field = Field(
    default=0.8,
    ge=0.0,
    le=2.0,
    description=(
        " Adjust the randomness of the generated text. \n \n "
        " Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model ' s output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run. "
    ),
)

top_p_field = Field(
    default=0.95,
    ge=0.0,
    le=1.0,
    description=(
        " Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. \n \n "
        " Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. "
    ),
)

stop_field = Field(
    default=None,
    description=" A list of tokens at which to stop generation. If None, no stop tokens are used. ",
)

stream_field = Field(
    default=False,
    description=" Whether to stream the results as they are generated. Useful for chatbots. ",
)

top_k_field = Field(
    default=40,
    ge=0,
    description=(
        " Limit the next token selection to the K most probable tokens. \n \n "
        " Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text. "
    ),
)

repeat_penalty_field = Field(
    default=1.1,
    ge=0.0,
    description=(
        " A penalty applied to each token that is already generated. This helps prevent the model from repeating itself. \n \n "
        " Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. "
    ),
)

presence_penalty_field = Field(
    default=0.0,
    ge=-2.0,
    le=2.0,
    description=" Positive values penalize new tokens based on whether they appear in the text so far, increasing the model ' s likelihood to talk about new topics. ",
)

frequency_penalty_field = Field(
    default=0.0,
    ge=-2.0,
    le=2.0,
    description=" Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model ' s likelihood to repeat the same line verbatim. ",
)
2023-05-07 02:52:20 -04:00
2023-05-12 07:21:46 -04:00
2023-04-28 22:43:37 -07:00
class CreateCompletionRequest(BaseModel):
    # Request body validated for the create_completion endpoint.
    #
    # Documented with comments rather than a docstring so the generated
    # schema description is left untouched.

    # A single string, or a list of strings; create_completion only accepts
    # lists with at most one entry.
    prompt: Union[str, List[str]] = Field(
        default=" ", description=" The prompt to generate completions for. "
    )
    suffix: Optional[str] = Field(
        default=None,
        description=" A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots. ",
    )
    max_tokens: int = max_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    echo: bool = Field(
        default=False,
        description=" Whether to echo the prompt in the generated text. Useful for chatbots. ",
    )
    stop: Optional[List[str]] = stop_field
    stream: bool = stream_field
    # Declared only here. A second ``logprobs`` declaration further down the
    # class body would replace this one and silently drop the ``ge=0`` bound
    # and the description.
    logprobs: Optional[int] = Field(
        default=None,
        ge=0,
        description=" The number of logprobs to generate. If None, no logprobs are generated. ",
    )
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    best_of: Optional[int] = 1
    logit_bias: Optional[Dict[str, float]] = Field(None)
    user: Optional[str] = Field(None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field

    class Config:
        # Example payload attached to the generated schema.
        schema_extra = {
            " example ": {
                " prompt ": " \n \n ### Instructions: \n What is the capital of France? \n \n ### Response: \n ",
                " stop ": [" \n ", " ### "],
            }
        }
# Response schema derived from the llama_cpp.Completion TypedDict.
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


@router.post(
    " /v1/completions ",
    response_model=CreateCompletionResponse,
)
def create_completion(
    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
    # Runs a completion on the shared model.
    #
    # Parameters:
    #   request: validated request body; every field except the excluded
    #            ones below is forwarded to the model call as a keyword.
    #   llama:   model instance supplied by the get_llama dependency.
    # Returns:
    #   An EventSourceResponse emitting one JSON-encoded chunk per event
    #   when ``request.stream`` is true, otherwise the completion object
    #   returned by the model call.
    # Raises:
    #   HTTPException (400) when ``prompt`` is a list with more than one
    #   entry.
    #
    # Kept as comments (not a docstring) so the generated OpenAPI
    # description of this route does not change.
    if isinstance(request.prompt, list):
        # Explicit check instead of ``assert``: assertions are removed under
        # ``python -O`` (which would silently use only the first prompt) and
        # otherwise surface to the client as an internal server error.
        if len(request.prompt) > 1:
            raise HTTPException(
                status_code=400,
                detail=(
                    "Only a single prompt per request is supported; "
                    f"received {len(request.prompt)}."
                ),
            )
        request.prompt = request.prompt[0] if request.prompt else " "

    completion_or_chunks = llama(
        **request.dict(
            exclude={
                " n ",
                " best_of ",
                " logit_bias ",
                " user ",
            }
        )
    )

    if request.stream:
        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)

    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
    return completion
class CreateEmbeddingRequest(BaseModel):
    # Request body validated for the create_embedding endpoint.
    # create_embedding drops ``user`` before calling the model; the other
    # fields are forwarded as keyword arguments.
    model: Optional[str] = model_field
    input: str = Field(
        description=" The input to embed. ",
    )
    user: Optional[str]

    class Config:
        # Example payload attached to the generated schema.
        schema_extra = {
            " example ": {
                " input ": " The food was delicious and the waiter... ",
            }
        }
# Response schema derived from the llama_cpp.Embedding TypedDict.
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)


@router.post(
    " /v1/embeddings ",
    response_model=CreateEmbeddingResponse,
)
def create_embedding(
    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
    # Forwards every request field except the excluded one to
    # ``llama.create_embedding`` and returns its result unchanged.
    # (Comments rather than a docstring, so the OpenAPI description of the
    # route is not altered.)
    embedding_kwargs = request.dict(exclude={" user "})
    return llama.create_embedding(**embedding_kwargs)
2023-04-28 22:43:37 -07:00
class ChatCompletionRequestMessage(BaseModel):
    # One message of a chat conversation: who is speaking and what they said.
    role: Literal[" system ", " user ", " assistant "] = Field(
        default=" user ",
        description=" The role of the message. ",
    )
    content: str = Field(
        default=" ",
        description=" The content of the message. ",
    )
2023-04-28 22:43:37 -07:00
class CreateChatCompletionRequest(BaseModel):
    # Request body validated for the create_chat_completion endpoint.
    # The sampling fields reuse the shared module-level Field definitions.
    messages: List[ChatCompletionRequestMessage] = Field(
        default=[],
        description=" A list of messages to generate completions for. ",
    )
    max_tokens: int = max_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    stop: Optional[List[str]] = stop_field
    stream: bool = stream_field
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    logit_bias: Optional[Dict[str, float]] = Field(None)
    user: Optional[str] = Field(None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field

    class Config:
        # Example payload attached to the generated schema: a system prompt
        # followed by a single user question.
        schema_extra = {
            " example ": {
                " messages ": [
                    ChatCompletionRequestMessage(
                        role=" system ",
                        content=" You are a helpful assistant. ",
                    ),
                    ChatCompletionRequestMessage(
                        role=" user ",
                        content=" What is the capital of France? ",
                    ),
                ]
            }
        }
# Response schema derived from the llama_cpp.ChatCompletion TypedDict.
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)


@router.post(
    " /v1/chat/completions ",
    response_model=CreateChatCompletionResponse,
)
def create_chat_completion(
    request: CreateChatCompletionRequest,
    llama: llama_cpp.Llama = Depends(get_llama),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
    # Runs a chat completion on the shared model. Every request field except
    # the excluded ones is forwarded to ``llama.create_chat_completion``.
    # Non-streaming requests get the returned object back unchanged;
    # streaming requests get an EventSourceResponse that emits one
    # JSON-encoded chunk per event followed by a final sentinel event.
    # (Comments rather than a docstring, so the OpenAPI description of the
    # route is not altered.)
    not_forwarded = {
        " n ",
        " logit_bias ",
        " user ",
    }
    completion_or_chunks = llama.create_chat_completion(
        **request.dict(exclude=not_forwarded),
    )

    if not request.stream:
        completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
        return completion

    async def server_sent_events(
        chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
    ):
        # One event per chunk, then the terminating sentinel.
        for chat_chunk in chat_chunks:
            yield dict(data=json.dumps(chat_chunk))
        yield dict(data=" [DONE] ")

    chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
    return EventSourceResponse(server_sent_events(chunks))
class ModelData(TypedDict):
    # Shape of a single entry in the model listing built by get_models().
    id: str
    object: Literal[" model "]
    owned_by: str
    permissions: List[str]
class ModelList(TypedDict):
    # Envelope returned by get_models(): a list marker plus the entries.
    object: Literal[" list "]
    data: List[ModelData]
# Response schema derived from the ModelList TypedDict.
GetModelResponse = create_model_from_typeddict(ModelList)


@router.get(" /v1/models ", response_model=GetModelResponse)
def get_models(
    settings: Settings = Depends(get_settings),
    llama: llama_cpp.Llama = Depends(get_llama),
) -> ModelList:
    # Returns a listing with exactly one entry, for the loaded model.
    # The entry's identifier is the configured alias when one is set,
    # otherwise the model's path as reported by the Llama instance.
    # (Comments rather than a docstring, so the OpenAPI description of the
    # route is not altered.)
    if settings.model_alias is not None:
        model_id = settings.model_alias
    else:
        model_id = llama.model_path

    model_entry = {
        " id ": model_id,
        " object ": " model ",
        " owned_by ": " me ",
        " permissions ": [],
    }
    return {
        " object ": " list ",
        " data ": [model_entry],
    }