2023-04-28 22:43:37 -07:00
import json
2023-05-07 03:03:57 -04:00
import multiprocessing
2023-04-28 22:43:37 -07:00
from threading import Lock
2023-05-27 09:12:58 -04:00
from functools import partial
from typing import Iterator , List , Optional , Union , Dict
2023-05-07 03:03:57 -04:00
from typing_extensions import TypedDict , Literal
2023-04-28 22:43:37 -07:00
import llama_cpp
2023-05-27 09:12:58 -04:00
import anyio
from anyio . streams . memory import MemoryObjectSendStream
from starlette . concurrency import run_in_threadpool , iterate_in_threadpool
from fastapi import Depends , FastAPI , APIRouter , Request
2023-04-28 22:43:37 -07:00
from fastapi . middleware . cors import CORSMiddleware
2023-07-07 21:38:46 -04:00
from pydantic import BaseModel , Field
from pydantic_settings import BaseSettings
2023-04-28 22:43:37 -07:00
from sse_starlette . sse import EventSourceResponse
class Settings(BaseSettings):
    # Server configuration. Each attribute is a pydantic-settings field, so
    # values are validated against the constraints declared below.

    # --- model selection -------------------------------------------------
    model: str = Field(
        description="The path to the model to use for generating completions."
    )
    model_alias: Optional[str] = Field(
        description="The alias of the model to use for generating completions.",
        default=None,
    )

    # --- llama.cpp context / hardware parameters --------------------------
    n_ctx: int = Field(description="The context size.", default=2048, ge=1)
    n_gpu_layers: int = Field(
        description="The number of layers to put on the GPU. The rest will be on the CPU.",
        default=0,
        ge=0,
    )
    tensor_split: Optional[List[float]] = Field(
        description="Split layers across multiple GPUs in proportion.",
        default=None,
    )
    seed: int = Field(description="Random seed. -1 for random.", default=1337)
    n_batch: int = Field(
        description="The batch size to use per eval.", default=512, ge=1
    )
    n_threads: int = Field(
        description="The number of threads to use.",
        # Half of the reported CPUs, but never fewer than one thread.
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=1,
    )
    f16_kv: bool = Field(description="Whether to use f16 key/value.", default=True)
    use_mlock: bool = Field(
        description="Use mlock.",
        default=llama_cpp.llama_mlock_supported(),
    )
    use_mmap: bool = Field(
        description="Use mmap.",
        default=llama_cpp.llama_mmap_supported(),
    )
    embedding: bool = Field(description="Whether to use embeddings.", default=True)
    low_vram: bool = Field(
        description="Whether to use less VRAM. This will reduce performance.",
        default=False,
    )
    last_n_tokens_size: int = Field(
        description="Last n tokens to keep for repeat penalty calculation.",
        default=64,
        ge=0,
    )
    logits_all: bool = Field(description="Whether to return logits.", default=True)

    # --- prompt cache -----------------------------------------------------
    cache: bool = Field(
        description="Use a cache to reduce processing times for evaluated prompts.",
        default=False,
    )
    cache_type: Literal["ram", "disk"] = Field(
        description="The type of cache to use. Only used if cache is True.",
        default="ram",
    )
    cache_size: int = Field(
        description="The size of the cache in bytes. Only used if cache is True.",
        default=2 << 30,
    )

    # --- miscellaneous ----------------------------------------------------
    vocab_only: bool = Field(
        description="Whether to only return the vocabulary.", default=False
    )
    verbose: bool = Field(
        description="Whether to print debug information.", default=True
    )

    # --- HTTP server ------------------------------------------------------
    host: str = Field(description="Listen address", default="localhost")
    port: int = Field(description="Listen port", default=8000)
    interrupt_requests: bool = Field(
        description="Whether to interrupt requests when a new request is received.",
        default=True,
    )
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
# Router that collects every endpoint declared in this module; create_app
# mounts it on the FastAPI application.
router = APIRouter()

# Process-wide state populated by create_app(): the active settings and the
# single shared model instance. Both stay None until create_app() has run.
settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
def create_app(settings: Optional[Settings] = None):
    """Build the FastAPI application and load the shared llama model.

    Args:
        settings: Server configuration. When omitted, a default ``Settings()``
            instance is constructed.

    Returns:
        The configured ``FastAPI`` app with permissive CORS middleware and the
        module-level ``router`` mounted.

    Side effects:
        Rebinds the module globals ``llama`` (the loaded model) and
        ``settings`` (the configuration in effect).
    """
    if settings is None:
        settings = Settings()

    app = FastAPI(
        title="🦙 llama.cpp Python API",
        version="0.0.1",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    # NOTE(review): Settings.low_vram is declared but is not forwarded here;
    # confirm whether llama_cpp.Llama accepts it before wiring it through.
    global llama
    llama = llama_cpp.Llama(
        model_path=settings.model,
        n_gpu_layers=settings.n_gpu_layers,
        tensor_split=settings.tensor_split,
        seed=settings.seed,
        f16_kv=settings.f16_kv,
        use_mlock=settings.use_mlock,
        use_mmap=settings.use_mmap,
        embedding=settings.embedding,
        logits_all=settings.logits_all,
        n_threads=settings.n_threads,
        n_batch=settings.n_batch,
        n_ctx=settings.n_ctx,
        last_n_tokens_size=settings.last_n_tokens_size,
        vocab_only=settings.vocab_only,
        verbose=settings.verbose,
    )

    if settings.cache:
        # Honour cache_type. Previously the chosen cache was unconditionally
        # replaced by a plain LlamaCache right after this branch, so the
        # "disk" option never took effect.
        if settings.cache_type == "disk":
            if settings.verbose:
                print(f"Using disk cache with size {settings.cache_size}")
            cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
        else:
            if settings.verbose:
                print(f"Using ram cache with size {settings.cache_size}")
            cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
        llama.set_cache(cache)

    # ``settings`` is a parameter of create_app, so it cannot be declared
    # global in this scope; the nested helper publishes it to the module.
    def set_settings(_settings: Settings):
        global settings
        settings = _settings

    set_settings(settings)
    return app
2023-04-28 22:43:37 -07:00
2023-07-07 03:04:17 -04:00
# Pair of locks used by get_llama() to serialise access to the shared model.
# A request holds the outer lock only while it waits for the inner one, so
# streaming handlers can poll llama_outer_lock.locked() to detect a waiter.
llama_outer_lock, llama_inner_lock = Lock(), Lock()
2023-05-01 22:38:46 -04:00
2023-04-28 22:43:37 -07:00
def get_llama():
    """Yield the shared model while holding the inner lock.

    Generator dependency: the outer lock is taken first, then the inner lock;
    the outer lock is dropped as soon as the inner one has been obtained, and
    the inner lock is released when the generator is finalised.
    """
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    llama_outer_lock.acquire()
    outer_still_held = True
    try:
        with llama_inner_lock:
            # We own the model now; stop signalling "request pending".
            llama_outer_lock.release()
            outer_still_held = False
            yield llama
    finally:
        # Only reached with the outer lock held if acquiring the inner lock
        # did not complete.
        if outer_still_held:
            llama_outer_lock.release()
2023-04-28 22:43:37 -07:00
2023-05-07 02:52:20 -04:00
2023-05-16 17:22:00 -04:00
def get_settings():
    """Generator dependency that yields the module-level ``settings`` object."""
    yield settings
2023-07-13 23:25:12 -04:00
# Field definitions shared by the completion and chat-completion request
# models below, so both endpoints expose identical defaults and bounds.

model_field = Field(
    default=None,
    description="The model to use for generating completions.",
)

max_tokens_field = Field(
    description="The maximum number of tokens to generate.",
    default=16,
    ge=1,
    le=2048,
)

temperature_field = Field(
    description=(
        "Adjust the randomness of the generated text.\n\n"
        "Temperature is a hyperparameter that controls the randomness of the generated text. "
        "It affects the probability distribution of the model's output tokens. "
        "A higher temperature (e.g., 1.5) makes the output more random and creative, "
        "while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. "
        "The default value is 0.8, which provides a balance between randomness and determinism. "
        "At the extreme, a temperature of 0 will always pick the most likely next token, "
        "leading to identical outputs in each run."
    ),
    default=0.8,
    ge=0.0,
    le=2.0,
)

top_p_field = Field(
    description=(
        "Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
        "Top-p sampling, also known as nucleus sampling, is another text generation method "
        "that selects the next token from a subset of tokens that together have a cumulative probability of at least p. "
        "This method provides a balance between diversity and quality by considering both "
        "the probabilities of tokens and the number of tokens to sample from. "
        "A higher value for top_p (e.g., 0.95) will lead to more diverse text, "
        "while a lower value (e.g., 0.5) will generate more focused and conservative text."
    ),
    default=0.95,
    ge=0.0,
    le=1.0,
)

stop_field = Field(
    description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
    default=None,
)

stream_field = Field(
    description="Whether to stream the results as they are generated. Useful for chatbots.",
    default=False,
)

top_k_field = Field(
    description=(
        "Limit the next token selection to the K most probable tokens.\n\n"
        "Top-k sampling is a text generation method that selects the next token only from "
        "the top k most likely tokens predicted by the model. "
        "It helps reduce the risk of generating low-probability or nonsensical tokens, "
        "but it may also limit the diversity of the output. "
        "A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, "
        "while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text."
    ),
    default=40,
    ge=0,
)

repeat_penalty_field = Field(
    description=(
        "A penalty applied to each token that is already generated. "
        "This helps prevent the model from repeating itself.\n\n"
        "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. "
        "It helps prevent the model from generating repetitive or monotonous text. "
        "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
        "while a lower value (e.g., 0.9) will be more lenient."
    ),
    default=1.1,
    ge=0.0,
)

presence_penalty_field = Field(
    description=(
        "Positive values penalize new tokens based on whether they appear in the text so far, "
        "increasing the model's likelihood to talk about new topics."
    ),
    default=0.0,
    ge=-2.0,
    le=2.0,
)

frequency_penalty_field = Field(
    description=(
        "Positive values penalize new tokens based on their existing frequency in the text so far, "
        "decreasing the model's likelihood to repeat the same line verbatim."
    ),
    default=0.0,
    ge=-2.0,
    le=2.0,
)

mirostat_mode_field = Field(
    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
    default=0,
    ge=0,
    le=2,
)

mirostat_tau_field = Field(
    description=(
        "Mirostat target entropy, i.e. the target perplexity - "
        "lower values produce focused and coherent text, "
        "larger values produce more diverse and less coherent text"
    ),
    default=5.0,
    ge=0.0,
    le=10.0,
)

mirostat_eta_field = Field(
    description="Mirostat learning rate",
    default=0.1,
    ge=0.001,
    le=1.0,
)
2023-05-12 07:21:46 -04:00
2023-04-28 22:43:37 -07:00
class CreateCompletionRequest(BaseModel):
    # Request body accepted by POST /v1/completions. Sampling parameters reuse
    # the shared module-level Field definitions so they stay aligned with the
    # chat-completion request model.
    prompt: Union[str, List[str]] = Field(
        default="", description="The prompt to generate completions for."
    )
    suffix: Optional[str] = Field(
        default=None,
        description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
    )
    max_tokens: int = max_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    echo: bool = Field(
        default=False,
        description="Whether to echo the prompt in the generated text. Useful for chatbots.",
    )
    stop: Optional[Union[str, List[str]]] = stop_field
    stream: bool = stream_field
    # Declared exactly once: a second `logprobs = Field(None)` further down
    # used to override this definition and silently dropped the ge=0 bound
    # and the description.
    logprobs: Optional[int] = Field(
        default=None,
        ge=0,
        description="The number of logprobs to generate. If None, no logprobs are generated.",
    )
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(None)

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    best_of: Optional[int] = 1
    user: Optional[str] = Field(default=None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    # How the keys of logit_bias are interpreted; None is treated as
    # "input_ids" by make_logit_bias_processor.
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                    "stop": ["\n", "###"],
                }
            ]
        }
    }
2023-04-28 22:43:37 -07:00
2023-06-09 13:13:08 -04:00
def make_logit_bias_processor(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
    """Build a logits processor that adds fixed biases to selected token ids.

    Args:
        llama: Object whose ``tokenize`` method is used in ``"tokens"`` mode.
        logit_bias: Mapping from key to bias. Keys are stringified token ids
            in ``"input_ids"`` mode, or text to be tokenized in ``"tokens"``
            mode.
        logit_bias_type: Interpretation of the keys; ``None`` means
            ``"input_ids"``.

    Returns:
        A callable ``(input_ids, scores) -> new_scores`` returning a new list
        where each position has its bias (0.0 if none) added.
    """
    mode = "input_ids" if logit_bias_type is None else logit_bias_type

    bias_by_id: Dict[int, float] = {}
    if mode == "input_ids":
        bias_by_id.update(
            (int(raw_id), bias) for raw_id, bias in logit_bias.items()
        )
    elif mode == "tokens":
        for text, bias in logit_bias.items():
            encoded = text.encode("utf-8")
            # Every id produced for the text receives the same bias.
            for token_id in llama.tokenize(encoded, add_bos=False):
                bias_by_id[token_id] = bias

    def logit_bias_processor(
        input_ids: List[int],
        scores: List[float],
    ) -> List[float]:
        # The position in ``scores`` is the token id being scored.
        return [
            score + bias_by_id.get(position, 0.0)
            for position, score in enumerate(scores)
        ]

    return logit_bias_processor
2023-05-01 22:38:46 -04:00
@router.post(
    "/v1/completions",
)
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
    llama: llama_cpp.Llama = Depends(get_llama),
    settings: Settings = Depends(get_settings),
) -> llama_cpp.Completion:
    # Text-completion endpoint.
    #
    # `settings` is injected through Depends(get_settings), matching
    # create_chat_completion, instead of reading the Optional module global
    # from inside the streaming closure.
    #
    # Returns the completion produced by calling `llama(**kwargs)` in a worker
    # thread, or an EventSourceResponse streaming JSON chunks followed by a
    # "[DONE]" sentinel when `body.stream` is set.

    # At most one prompt is handled per request; unwrap the single-element list.
    if isinstance(body.prompt, list):
        assert len(body.prompt) <= 1
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    # Fields that are not passed on as keyword arguments to the model call.
    exclude = {
        "n",
        "best_of",
        "logit_bias",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
            [
                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
            ]
        )

    if body.stream:
        send_chan, recv_chan = anyio.create_memory_object_stream(10)

        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
            async with inner_send_chan:
                try:
                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
                    async for chunk in iterate_in_threadpool(iterator):
                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
                        if await request.is_disconnected():
                            raise anyio.get_cancelled_exc_class()()
                        # The outer lock being held means another request is
                        # waiting in get_llama(); end this stream early.
                        if settings.interrupt_requests and llama_outer_lock.locked():
                            await inner_send_chan.send(dict(data="[DONE]"))
                            raise anyio.get_cancelled_exc_class()()
                    await inner_send_chan.send(dict(data="[DONE]"))
                except anyio.get_cancelled_exc_class() as e:
                    print("disconnected")
                    with anyio.move_on_after(1, shield=True):
                        print(
                            f"Disconnected from client (via refresh/close) {request.client}"
                        )
                        raise e

        return EventSourceResponse(
            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
        )  # type: ignore
    else:
        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
        return completion
2023-04-28 22:43:37 -07:00
class CreateEmbeddingRequest(BaseModel):
    # Request body accepted by POST /v1/embeddings.
    model: Optional[str] = model_field
    input: Union[str, List[str]] = Field(description="The input to embed.")
    # Dropped by create_embedding before the model is called.
    user: Optional[str] = Field(None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {"input": "The food was delicious and the waiter..."},
            ]
        }
    }
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
@router.post("/v1/embeddings")
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama: llama_cpp.Llama = Depends(get_llama),
):
    # Embedding endpoint: every request field except "user" is passed as a
    # keyword argument to llama.create_embedding, which runs in a worker
    # thread so the event loop is not blocked.
    embedding_kwargs = request.model_dump(exclude={"user"})
    result = await run_in_threadpool(llama.create_embedding, **embedding_kwargs)
    return result
2023-04-28 22:43:37 -07:00
class ChatCompletionRequestMessage(BaseModel):
    # One entry of the `messages` list of a chat-completion request.
    role: Literal["system", "user", "assistant"] = Field(
        description="The role of the message.",
        default="user",
    )
    content: str = Field(
        description="The content of the message.",
        default="",
    )
2023-04-28 22:43:37 -07:00
class CreateChatCompletionRequest(BaseModel):
    # Request body accepted by POST /v1/chat/completions. Sampling parameters
    # reuse the shared module-level Field definitions.
    messages: List[ChatCompletionRequestMessage] = Field(
        description="A list of messages to generate completions for.",
        default=[],
    )
    max_tokens: int = max_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    stop: Optional[List[str]] = stop_field
    stream: bool = stream_field
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(default=None)

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    user: Optional[str] = Field(default=None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(default=None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a helpful assistant.",
                        },
                        {
                            "role": "user",
                            "content": "What is the capital of France?",
                        },
                    ]
                }
            ]
        }
    }
2023-04-28 22:43:37 -07:00
2023-05-01 22:38:46 -04:00
@router.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest,
    llama: llama_cpp.Llama = Depends(get_llama),
    settings: Settings = Depends(get_settings),
) -> llama_cpp.ChatCompletion:
    # Chat-completion endpoint. Non-streaming requests return the result of
    # llama.create_chat_completion; streaming requests return an
    # EventSourceResponse that emits JSON chunks followed by "[DONE]".

    # Everything except these fields becomes a keyword argument of the call.
    call_kwargs = body.model_dump(
        exclude={"n", "logit_bias", "logit_bias_type", "user"}
    )

    if body.logit_bias is not None:
        bias_processor = make_logit_bias_processor(
            llama, body.logit_bias, body.logit_bias_type
        )
        call_kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
            [bias_processor]
        )

    if not body.stream:
        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
            llama.create_chat_completion, **call_kwargs  # type: ignore
        )
        return completion

    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan: MemoryObjectSendStream):
        async with inner_send_chan:
            try:
                chunk_iter: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(
                    llama.create_chat_completion, **call_kwargs
                )  # type: ignore
                async for chat_chunk in iterate_in_threadpool(chunk_iter):
                    await inner_send_chan.send({"data": json.dumps(chat_chunk)})
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
                    # A held outer lock means another request is queued in
                    # get_llama(); finish this stream early to let it run.
                    if settings.interrupt_requests and llama_outer_lock.locked():
                        await inner_send_chan.send({"data": "[DONE]"})
                        raise anyio.get_cancelled_exc_class()()
                await inner_send_chan.send({"data": "[DONE]"})
            except anyio.get_cancelled_exc_class() as cancelled:
                print("disconnected")
                with anyio.move_on_after(1, shield=True):
                    print(
                        f"Disconnected from client (via refresh/close) {request.client}"
                    )
                    raise cancelled

    publisher = partial(event_publisher, send_chan)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=publisher,
    )  # type: ignore
2023-04-28 22:43:37 -07:00
# Shape of a single entry in the /v1/models listing.
ModelData = TypedDict(
    "ModelData",
    {
        "id": str,
        "object": Literal["model"],
        "owned_by": str,
        "permissions": List[str],
    },
)
# Shape of the /v1/models response body: a tagged list of ModelData entries.
ModelList = TypedDict(
    "ModelList",
    {
        "object": Literal["list"],
        "data": List[ModelData],
    },
)
2023-07-07 21:38:46 -04:00
@router.get("/v1/models")
async def get_models(
    settings: Settings = Depends(get_settings),
) -> ModelList:
    # Lists the single loaded model. Its id is the configured alias when one
    # is set, otherwise the path the model was loaded from.
    assert llama is not None

    if settings.model_alias is not None:
        model_id = settings.model_alias
    else:
        model_id = llama.model_path

    entry = {
        "id": model_id,
        "object": "model",
        "owned_by": "me",
        "permissions": [],
    }
    return {"object": "list", "data": [entry]}