misc: Format

Andrei Betlen 2024-02-28 14:27:40 -05:00
parent 0d37ce52b1
commit 727d60c28a
5 changed files with 44 additions and 39 deletions

View file

@@ -199,8 +199,8 @@ async def authenticate(
 @router.post(
     "/v1/completions",
     summary="Completion",
     dependencies=[Depends(authenticate)],
-    response_model= Union[
+    response_model=Union[
         llama_cpp.CreateCompletionResponse,
         str,
     ],
@@ -211,19 +211,19 @@ async def authenticate(
                 "application/json": {
                     "schema": {
                         "anyOf": [
                             {"$ref": "#/components/schemas/CreateCompletionResponse"}
                         ],
                         "title": "Completion response, when stream=False",
                     }
                 },
-                "text/event-stream":{
+                "text/event-stream": {
                     "schema": {
                         "type": "string",
-                        "title": "Server Side Streaming response, when stream=True. " +
-                        "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
-                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                        "title": "Server Side Streaming response, when stream=True. "
+                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
+                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                     }
-                }
+                },
             },
         }
     },
@@ -290,7 +290,7 @@ async def create_completion(
                 inner_send_chan=send_chan,
                 iterator=iterator(),
             ),
-            sep='\n',
+            sep="\n",
         )
     else:
         return iterator_or_completion
@@ -310,10 +310,10 @@ async def create_embedding(
 @router.post(
-    "/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)],
-    response_model= Union[
-        llama_cpp.ChatCompletion, str
-    ],
+    "/v1/chat/completions",
+    summary="Chat",
+    dependencies=[Depends(authenticate)],
+    response_model=Union[llama_cpp.ChatCompletion, str],
     responses={
         "200": {
             "description": "Successful Response",
@@ -321,19 +321,21 @@ async def create_embedding(
                 "application/json": {
                     "schema": {
                         "anyOf": [
-                            {"$ref": "#/components/schemas/CreateChatCompletionResponse"}
+                            {
+                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
+                            }
                         ],
                         "title": "Completion response, when stream=False",
                     }
                 },
-                "text/event-stream":{
+                "text/event-stream": {
                     "schema": {
                         "type": "string",
-                        "title": "Server Side Streaming response, when stream=True" +
-                        "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
-                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
+                        "title": "Server Side Streaming response, when stream=True"
+                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
+                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                     }
-                }
+                },
             },
         }
     },
@@ -383,7 +385,7 @@ async def create_chat_completion(
                 inner_send_chan=send_chan,
                 iterator=iterator(),
             ),
-            sep='\n',
+            sep="\n",
         )
     else:
         return iterator_or_completion

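The changes above are formatting only (black-style wrapping, double quotes, and moving the string-concatenation "+" to the start of the continuation line); the OpenAPI metadata itself is unchanged. As a minimal, self-contained sketch of the pattern being reformatted (a hypothetical standalone FastAPI app, not the project's actual handler), the responses mapping is what documents that the same endpoint returns either a JSON body or a text/event-stream body depending on stream:

from typing import Union

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class CompletionResponse(BaseModel):
    # Hypothetical stand-in for llama_cpp.CreateCompletionResponse.
    text: str


@app.post(
    "/v1/completions",
    summary="Completion",
    response_model=Union[CompletionResponse, str],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                # Body returned when the request has stream=False.
                "application/json": {
                    "schema": {"$ref": "#/components/schemas/CompletionResponse"}
                },
                # Body returned when stream=True: a sequence of SSE frames,
                # "data: <json>\n\n ... data: [DONE]".
                "text/event-stream": {"schema": {"type": "string"}},
            },
        }
    },
)
async def create_completion(stream: bool = False) -> Union[CompletionResponse, str]:
    # A real server would branch on `stream` and return a streaming SSE response here.
    return CompletionResponse(text="...")

The sep="\n" argument touched in the two small hunks configures the line separator used by the streaming (SSE) response object, matching the "data: ... \n\n ... data: [DONE]" framing shown in the documented examples.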
View file

@@ -22,6 +22,7 @@ from llama_cpp.server.types import (
     CreateChatCompletionRequest,
 )

+
 class ErrorResponse(TypedDict):
     """OpenAI style error response"""
@@ -75,7 +76,7 @@ class ErrorResponseFormatters:
                 (completion_tokens or 0) + prompt_tokens,
                 prompt_tokens,
                 completion_tokens,
-            ), # type: ignore
+            ),  # type: ignore
             type="invalid_request_error",
             param="messages",
             code="context_length_exceeded",
@@ -207,4 +208,3 @@ class RouteErrorHandler(APIRoute):
                 )

         return custom_route_handler
-

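For reference, ErrorResponse is a TypedDict for OpenAI-style error payloads; the formatter hunk above fills it with message, type, param, and code. A rough illustrative sketch of that shape (the real field definitions are not part of this diff):

from typing import Optional, TypedDict


class ExampleErrorResponse(TypedDict):
    """OpenAI style error response (illustrative stand-in)."""

    message: str
    type: str
    param: Optional[str]
    code: Optional[str]


# Roughly what the context_length_exceeded formatter above would produce.
error: ExampleErrorResponse = {
    "message": "Requested tokens exceed the model's context window.",
    "type": "invalid_request_error",
    "param": "messages",
    "code": "context_length_exceeded",
}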
View file

@@ -88,15 +88,15 @@ class LlamaProxy:
             assert (
                 settings.hf_tokenizer_config_path is not None
             ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
-            chat_handler = (
-                llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
-                    json.load(open(settings.hf_tokenizer_config_path))
-                )
-            )
+            chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
+                json.load(open(settings.hf_tokenizer_config_path))
+            )

         tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
         if settings.hf_pretrained_model_name_or_path is not None:
-            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)
+            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+                settings.hf_pretrained_model_name_or_path
+            )

         draft_model = None
         if settings.draft_model is not None:
@@ -120,17 +120,20 @@ class LlamaProxy:
                         kv_overrides[key] = float(value)
                     else:
                         raise ValueError(f"Unknown value type {value_type}")

         import functools

         kwargs = {}

         if settings.hf_model_repo_id is not None:
-            create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model)
+            create_fn = functools.partial(
+                llama_cpp.Llama.from_pretrained,
+                repo_id=settings.hf_model_repo_id,
+                filename=settings.model,
+            )
         else:
             create_fn = llama_cpp.Llama
             kwargs["model_path"] = settings.model

         _model = create_fn(
             **kwargs,

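Both hunks in this file are pure re-wrapping, but the second one is built around a pattern worth noting: functools.partial pre-binds the Hugging Face hub arguments so that both branches end in a single create_fn(**kwargs) call. A self-contained sketch of that dispatch, using dummy loader functions in place of llama_cpp.Llama and Llama.from_pretrained (the repo and file names below are made-up placeholders):

import functools
from typing import Any, Callable, Dict, Optional, Tuple


def load_from_hub(repo_id: str, filename: str, **kwargs: Any) -> str:
    # Stand-in for llama_cpp.Llama.from_pretrained: fetch from the hub, then load.
    return f"loaded {filename} from {repo_id} with {kwargs}"


def load_from_path(model_path: str, **kwargs: Any) -> str:
    # Stand-in for llama_cpp.Llama(model_path=...): load a local file.
    return f"loaded local file {model_path} with {kwargs}"


def build_create_fn(
    hf_model_repo_id: Optional[str], model: str
) -> Tuple[Callable[..., str], Dict[str, Any]]:
    kwargs: Dict[str, Any] = {}
    if hf_model_repo_id is not None:
        # Pre-bind the hub-specific arguments; the caller only passes **kwargs.
        create_fn = functools.partial(
            load_from_hub, repo_id=hf_model_repo_id, filename=model
        )
    else:
        create_fn = load_from_path
        kwargs["model_path"] = model
    return create_fn, kwargs


create_fn, kwargs = build_create_fn("example-org/example-gguf", "example.Q4_K_M.gguf")
print(create_fn(**kwargs, n_ctx=2048))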
View file

@ -74,7 +74,9 @@ class ModelSettings(BaseSettings):
ge=0, ge=0,
description="The number of threads to use when batch processing.", description="The number of threads to use when batch processing.",
) )
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency") rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field( rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor" default=0.0, description="RoPE frequency scaling factor"
@@ -193,6 +195,4 @@ class Settings(ServerSettings, ModelSettings):
 class ConfigFileSettings(ServerSettings):
     """Configuration file format settings."""

-    models: List[ModelSettings] = Field(
-        default=[], description="Model configs"
-    )
+    models: List[ModelSettings] = Field(default=[], description="Model configs")

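The settings hunks only re-wrap Field(...) calls. As a small sketch of how a declaration like models: List[ModelSettings] = Field(default=[], description="Model configs") behaves when a config file is parsed (plain pydantic BaseModel here, and cut-down stand-in classes rather than the project's real ServerSettings/ModelSettings):

import json
from typing import List

from pydantic import BaseModel, Field


class ExampleModelSettings(BaseModel):
    # Only a couple of illustrative fields.
    model: str = Field(description="Path to the model file.")
    rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")


class ExampleConfigFileSettings(BaseModel):
    """Configuration file format settings (illustrative stand-in)."""

    models: List[ExampleModelSettings] = Field(default=[], description="Model configs")


raw = '{"models": [{"model": "models/example.gguf"}]}'
config = ExampleConfigFileSettings(**json.loads(raw))
print(config.models[0].rope_freq_base)  # 0.0, taken from the Field default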
View file

@@ -110,7 +110,7 @@ class CreateCompletionRequest(BaseModel):
         default=None,
         description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
     )
     max_tokens: Optional[int] = Field(
         default=16, ge=0, description="The maximum number of tokens to generate."
     )
     temperature: float = temperature_field
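The ge=0 argument on max_tokens above is a pydantic validation constraint, so out-of-range values are rejected before the request reaches the handler. A minimal sketch with an illustrative model (not the project's actual CreateCompletionRequest):

from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class ExampleCompletionRequest(BaseModel):
    suffix: Optional[str] = Field(
        default=None,
        description="A suffix to append to the generated text.",
    )
    max_tokens: Optional[int] = Field(
        default=16, ge=0, description="The maximum number of tokens to generate."
    )


print(ExampleCompletionRequest().max_tokens)  # 16, the declared default

try:
    ExampleCompletionRequest(max_tokens=-1)
except ValidationError as exc:
    # ge=0 rejects negative values at parse time.
    print(exc.errors()[0]["msg"])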