From 727d60c28a76312c3cbed412f290af8cad2f92ec Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 28 Feb 2024 14:27:40 -0500 Subject: [PATCH] misc: Format --- llama_cpp/server/app.py | 50 +++++++++++++++++++----------------- llama_cpp/server/errors.py | 4 +-- llama_cpp/server/model.py | 19 ++++++++------ llama_cpp/server/settings.py | 8 +++--- llama_cpp/server/types.py | 2 +- 5 files changed, 44 insertions(+), 39 deletions(-) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 7a1391d..ec92809 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -199,8 +199,8 @@ async def authenticate( @router.post( "/v1/completions", summary="Completion", - dependencies=[Depends(authenticate)], - response_model= Union[ + dependencies=[Depends(authenticate)], + response_model=Union[ llama_cpp.CreateCompletionResponse, str, ], @@ -211,19 +211,19 @@ async def authenticate( "application/json": { "schema": { "anyOf": [ - {"$ref": "#/components/schemas/CreateCompletionResponse"} + {"$ref": "#/components/schemas/CreateCompletionResponse"} ], "title": "Completion response, when stream=False", } }, - "text/event-stream":{ - "schema": { - "type": "string", - "title": "Server Side Streaming response, when stream=True. " + - "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 - "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""" + "text/event-stream": { + "schema": { + "type": "string", + "title": "Server Side Streaming response, when stream=True. " + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 + "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""", } - } + }, }, } }, @@ -290,7 +290,7 @@ async def create_completion( inner_send_chan=send_chan, iterator=iterator(), ), - sep='\n', + sep="\n", ) else: return iterator_or_completion @@ -310,10 +310,10 @@ async def create_embedding( @router.post( - "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)], - response_model= Union[ - llama_cpp.ChatCompletion, str - ], + "/v1/chat/completions", + summary="Chat", + dependencies=[Depends(authenticate)], + response_model=Union[llama_cpp.ChatCompletion, str], responses={ "200": { "description": "Successful Response", @@ -321,19 +321,21 @@ async def create_embedding( "application/json": { "schema": { "anyOf": [ - {"$ref": "#/components/schemas/CreateChatCompletionResponse"} + { + "$ref": "#/components/schemas/CreateChatCompletionResponse" + } ], "title": "Completion response, when stream=False", } }, - "text/event-stream":{ - "schema": { - "type": "string", - "title": "Server Side Streaming response, when stream=True" + - "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 - "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""" + "text/event-stream": { + "schema": { + "type": "string", + "title": "Server Side Streaming response, when stream=True" + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 + "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... 
data: [DONE]""", } - } + }, }, } }, @@ -383,7 +385,7 @@ async def create_chat_completion( inner_send_chan=send_chan, iterator=iterator(), ), - sep='\n', + sep="\n", ) else: return iterator_or_completion diff --git a/llama_cpp/server/errors.py b/llama_cpp/server/errors.py index 9d3d355..fbf9fd8 100644 --- a/llama_cpp/server/errors.py +++ b/llama_cpp/server/errors.py @@ -22,6 +22,7 @@ from llama_cpp.server.types import ( CreateChatCompletionRequest, ) + class ErrorResponse(TypedDict): """OpenAI style error response""" @@ -75,7 +76,7 @@ class ErrorResponseFormatters: (completion_tokens or 0) + prompt_tokens, prompt_tokens, completion_tokens, - ), # type: ignore + ), # type: ignore type="invalid_request_error", param="messages", code="context_length_exceeded", @@ -207,4 +208,3 @@ class RouteErrorHandler(APIRoute): ) return custom_route_handler - diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index 816d089..dace8d5 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -88,15 +88,15 @@ class LlamaProxy: assert ( settings.hf_tokenizer_config_path is not None ), "hf_tokenizer_config_path must be set for hf-tokenizer-config" - chat_handler = ( - llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler( - json.load(open(settings.hf_tokenizer_config_path)) - ) + chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler( + json.load(open(settings.hf_tokenizer_config_path)) ) tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None if settings.hf_pretrained_model_name_or_path is not None: - tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path) + tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained( + settings.hf_pretrained_model_name_or_path + ) draft_model = None if settings.draft_model is not None: @@ -120,17 +120,20 @@ class LlamaProxy: kv_overrides[key] = float(value) else: raise ValueError(f"Unknown value type {value_type}") - + import functools kwargs = {} if settings.hf_model_repo_id is not None: - create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model) + create_fn = functools.partial( + llama_cpp.Llama.from_pretrained, + repo_id=settings.hf_model_repo_id, + filename=settings.model, + ) else: create_fn = llama_cpp.Llama kwargs["model_path"] = settings.model - _model = create_fn( **kwargs, diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 292d7eb..daa913f 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -74,7 +74,9 @@ class ModelSettings(BaseSettings): ge=0, description="The number of threads to use when batch processing.", ) - rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) + rope_scaling_type: int = Field( + default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED + ) rope_freq_base: float = Field(default=0.0, description="RoPE base frequency") rope_freq_scale: float = Field( default=0.0, description="RoPE frequency scaling factor" @@ -193,6 +195,4 @@ class Settings(ServerSettings, ModelSettings): class ConfigFileSettings(ServerSettings): """Configuration file format settings.""" - models: List[ModelSettings] = Field( - default=[], description="Model configs" - ) + models: List[ModelSettings] = Field(default=[], description="Model configs") diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py index f0827d7..9a4b81e 100644 --- a/llama_cpp/server/types.py +++ 
b/llama_cpp/server/types.py @@ -110,7 +110,7 @@ class CreateCompletionRequest(BaseModel): default=None, description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.", ) - max_tokens: Optional[int] = Field( + max_tokens: Optional[int] = Field( default=16, ge=0, description="The maximum number of tokens to generate." ) temperature: float = temperature_field
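
For context on the model.py hunk: the reformatted code selects the model constructor up front, using functools.partial(Llama.from_pretrained, ...) when an HF repo id is configured and the plain Llama constructor with model_path otherwise. Below is a minimal standalone sketch of that pattern, assuming llama-cpp-python is installed; the repo id, filename, and local path are placeholder values, not values taken from the patch.

    # Sketch of the constructor-selection pattern touched in the model.py hunk
    # above. Placeholder values only; a real model file is required to load.
    import functools
    from typing import Any, Callable, Dict

    import llama_cpp

    hf_model_repo_id = None            # e.g. "some-org/some-model-GGUF" (placeholder)
    model = "./models/model.gguf"      # local path, or filename inside the HF repo

    kwargs: Dict[str, Any] = {}
    create_fn: Callable[..., llama_cpp.Llama]
    if hf_model_repo_id is not None:
        # Fix the hub-specific arguments up front via partial application.
        create_fn = functools.partial(
            llama_cpp.Llama.from_pretrained,
            repo_id=hf_model_repo_id,
            filename=model,
        )
    else:
        create_fn = llama_cpp.Llama
        kwargs["model_path"] = model

    # Loading a model is expensive, so the actual call is left commented out:
    # llm = create_fn(**kwargs, n_ctx=2048)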