misc: Format

Andrei Betlen 2024-02-28 14:27:40 -05:00
parent 0d37ce52b1
commit 727d60c28a
5 changed files with 44 additions and 39 deletions

View file

@@ -219,12 +219,12 @@ async def authenticate(
                 "text/event-stream": {
                     "schema": {
                         "type": "string",
-                        "title": "Server Side Streaming response, when stream=True. " +
-                        "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
-                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                        "title": "Server Side Streaming response, when stream=True. "
+                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
+                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                     }
-                }
+                },
             },
         }
     },
 )
@@ -290,7 +290,7 @@ async def create_completion(
                 inner_send_chan=send_chan,
                 iterator=iterator(),
             ),
-            sep='\n',
+            sep="\n",
         )
     else:
         return iterator_or_completion
@@ -310,10 +310,10 @@ async def create_embedding(
 @router.post(
-    "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)],
-    response_model= Union[
-        llama_cpp.ChatCompletion, str
-    ],
+    "/v1/chat/completions",
+    summary="Chat",
+    dependencies=[Depends(authenticate)],
+    response_model=Union[llama_cpp.ChatCompletion, str],
     responses={
         "200": {
             "description": "Successful Response",
@@ -321,7 +321,9 @@ async def create_embedding(
                 "application/json": {
                     "schema": {
                         "anyOf": [
-                            {"$ref": "#/components/schemas/CreateChatCompletionResponse"}
+                            {
+                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
+                            }
                         ],
                         "title": "Completion response, when stream=False",
                     }
@@ -329,12 +331,12 @@ async def create_embedding(
                 "text/event-stream": {
                     "schema": {
                         "type": "string",
-                        "title": "Server Side Streaming response, when stream=True" +
-                        "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
-                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                        "title": "Server Side Streaming response, when stream=True"
+                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
+                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                     }
-                }
+                },
             },
         }
     },
 )
@@ -383,7 +385,7 @@ async def create_chat_completion(
                 inner_send_chan=send_chan,
                 iterator=iterator(),
             ),
-            sep='\n',
+            sep="\n",
         )
     else:
         return iterator_or_completion

View file

@@ -22,6 +22,7 @@ from llama_cpp.server.types import (
     CreateChatCompletionRequest,
 )

+
 class ErrorResponse(TypedDict):
     """OpenAI style error response"""
@@ -207,4 +208,3 @@ class RouteErrorHandler(APIRoute):
         )

         return custom_route_handler
-

View file

@@ -88,15 +88,15 @@ class LlamaProxy:
             assert (
                 settings.hf_tokenizer_config_path is not None
             ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
-            chat_handler = (
-                llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
+            chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
                 json.load(open(settings.hf_tokenizer_config_path))
             )
-            )

         tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
         if settings.hf_pretrained_model_name_or_path is not None:
-            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)
+            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+                settings.hf_pretrained_model_name_or_path
+            )

         draft_model = None
         if settings.draft_model is not None:
@@ -126,12 +126,15 @@ class LlamaProxy:
         kwargs = {}
         if settings.hf_model_repo_id is not None:
-            create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model)
+            create_fn = functools.partial(
+                llama_cpp.Llama.from_pretrained,
+                repo_id=settings.hf_model_repo_id,
+                filename=settings.model,
+            )
         else:
             create_fn = llama_cpp.Llama
             kwargs["model_path"] = settings.model

         _model = create_fn(
             **kwargs,
             # Model Params

View file

@@ -74,7 +74,9 @@ class ModelSettings(BaseSettings):
         ge=0,
         description="The number of threads to use when batch processing.",
     )
-    rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED)
+    rope_scaling_type: int = Field(
+        default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
+    )
     rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
     rope_freq_scale: float = Field(
         default=0.0, description="RoPE frequency scaling factor"
@@ -193,6 +195,4 @@ class Settings(ServerSettings, ModelSettings):
 class ConfigFileSettings(ServerSettings):
     """Configuration file format settings."""

-    models: List[ModelSettings] = Field(
-        default=[], description="Model configs"
-    )
+    models: List[ModelSettings] = Field(default=[], description="Model configs")