misc: Format
This commit is contained in:
parent
0d37ce52b1
commit
727d60c28a
5 changed files with 44 additions and 39 deletions
|
@ -200,7 +200,7 @@ async def authenticate(
|
||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
summary="Completion",
|
summary="Completion",
|
||||||
dependencies=[Depends(authenticate)],
|
dependencies=[Depends(authenticate)],
|
||||||
response_model= Union[
|
response_model=Union[
|
||||||
llama_cpp.CreateCompletionResponse,
|
llama_cpp.CreateCompletionResponse,
|
||||||
str,
|
str,
|
||||||
],
|
],
|
||||||
|
@ -216,14 +216,14 @@ async def authenticate(
|
||||||
"title": "Completion response, when stream=False",
|
"title": "Completion response, when stream=False",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"text/event-stream":{
|
"text/event-stream": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"title": "Server Side Streaming response, when stream=True. " +
|
"title": "Server Side Streaming response, when stream=True. "
|
||||||
"See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
|
+ "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
|
||||||
"example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
|
"example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -290,7 +290,7 @@ async def create_completion(
|
||||||
inner_send_chan=send_chan,
|
inner_send_chan=send_chan,
|
||||||
iterator=iterator(),
|
iterator=iterator(),
|
||||||
),
|
),
|
||||||
sep='\n',
|
sep="\n",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return iterator_or_completion
|
return iterator_or_completion
|
||||||
|
@ -310,10 +310,10 @@ async def create_embedding(
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/chat/completions", summary="Chat", dependencies=[Depends(authenticate)],
|
"/v1/chat/completions",
|
||||||
response_model= Union[
|
summary="Chat",
|
||||||
llama_cpp.ChatCompletion, str
|
dependencies=[Depends(authenticate)],
|
||||||
],
|
response_model=Union[llama_cpp.ChatCompletion, str],
|
||||||
responses={
|
responses={
|
||||||
"200": {
|
"200": {
|
||||||
"description": "Successful Response",
|
"description": "Successful Response",
|
||||||
|
@ -321,19 +321,21 @@ async def create_embedding(
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"anyOf": [
|
"anyOf": [
|
||||||
{"$ref": "#/components/schemas/CreateChatCompletionResponse"}
|
{
|
||||||
|
"$ref": "#/components/schemas/CreateChatCompletionResponse"
|
||||||
|
}
|
||||||
],
|
],
|
||||||
"title": "Completion response, when stream=False",
|
"title": "Completion response, when stream=False",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"text/event-stream":{
|
"text/event-stream": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"title": "Server Side Streaming response, when stream=True" +
|
"title": "Server Side Streaming response, when stream=True"
|
||||||
"See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
|
+ "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
|
||||||
"example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]"""
|
"example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -383,7 +385,7 @@ async def create_chat_completion(
|
||||||
inner_send_chan=send_chan,
|
inner_send_chan=send_chan,
|
||||||
iterator=iterator(),
|
iterator=iterator(),
|
||||||
),
|
),
|
||||||
sep='\n',
|
sep="\n",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return iterator_or_completion
|
return iterator_or_completion
|
||||||
|
|
|
@ -22,6 +22,7 @@ from llama_cpp.server.types import (
|
||||||
CreateChatCompletionRequest,
|
CreateChatCompletionRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(TypedDict):
|
class ErrorResponse(TypedDict):
|
||||||
"""OpenAI style error response"""
|
"""OpenAI style error response"""
|
||||||
|
|
||||||
|
@ -75,7 +76,7 @@ class ErrorResponseFormatters:
|
||||||
(completion_tokens or 0) + prompt_tokens,
|
(completion_tokens or 0) + prompt_tokens,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
), # type: ignore
|
), # type: ignore
|
||||||
type="invalid_request_error",
|
type="invalid_request_error",
|
||||||
param="messages",
|
param="messages",
|
||||||
code="context_length_exceeded",
|
code="context_length_exceeded",
|
||||||
|
@ -207,4 +208,3 @@ class RouteErrorHandler(APIRoute):
|
||||||
)
|
)
|
||||||
|
|
||||||
return custom_route_handler
|
return custom_route_handler
|
||||||
|
|
||||||
|
|
|
@ -88,15 +88,15 @@ class LlamaProxy:
|
||||||
assert (
|
assert (
|
||||||
settings.hf_tokenizer_config_path is not None
|
settings.hf_tokenizer_config_path is not None
|
||||||
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
|
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
|
||||||
chat_handler = (
|
chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
|
||||||
llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
|
json.load(open(settings.hf_tokenizer_config_path))
|
||||||
json.load(open(settings.hf_tokenizer_config_path))
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
|
tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
|
||||||
if settings.hf_pretrained_model_name_or_path is not None:
|
if settings.hf_pretrained_model_name_or_path is not None:
|
||||||
tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)
|
tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
|
||||||
|
settings.hf_pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
|
||||||
draft_model = None
|
draft_model = None
|
||||||
if settings.draft_model is not None:
|
if settings.draft_model is not None:
|
||||||
|
@ -126,12 +126,15 @@ class LlamaProxy:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
if settings.hf_model_repo_id is not None:
|
if settings.hf_model_repo_id is not None:
|
||||||
create_fn = functools.partial(llama_cpp.Llama.from_pretrained, repo_id=settings.hf_model_repo_id, filename=settings.model)
|
create_fn = functools.partial(
|
||||||
|
llama_cpp.Llama.from_pretrained,
|
||||||
|
repo_id=settings.hf_model_repo_id,
|
||||||
|
filename=settings.model,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
create_fn = llama_cpp.Llama
|
create_fn = llama_cpp.Llama
|
||||||
kwargs["model_path"] = settings.model
|
kwargs["model_path"] = settings.model
|
||||||
|
|
||||||
|
|
||||||
_model = create_fn(
|
_model = create_fn(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
# Model Params
|
# Model Params
|
||||||
|
|
|
@ -74,7 +74,9 @@ class ModelSettings(BaseSettings):
|
||||||
ge=0,
|
ge=0,
|
||||||
description="The number of threads to use when batch processing.",
|
description="The number of threads to use when batch processing.",
|
||||||
)
|
)
|
||||||
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED)
|
rope_scaling_type: int = Field(
|
||||||
|
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
||||||
|
)
|
||||||
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
|
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
|
||||||
rope_freq_scale: float = Field(
|
rope_freq_scale: float = Field(
|
||||||
default=0.0, description="RoPE frequency scaling factor"
|
default=0.0, description="RoPE frequency scaling factor"
|
||||||
|
@ -193,6 +195,4 @@ class Settings(ServerSettings, ModelSettings):
|
||||||
class ConfigFileSettings(ServerSettings):
|
class ConfigFileSettings(ServerSettings):
|
||||||
"""Configuration file format settings."""
|
"""Configuration file format settings."""
|
||||||
|
|
||||||
models: List[ModelSettings] = Field(
|
models: List[ModelSettings] = Field(default=[], description="Model configs")
|
||||||
default=[], description="Model configs"
|
|
||||||
)
|
|
||||||
|
|
|
@ -110,7 +110,7 @@ class CreateCompletionRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
|
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
|
||||||
)
|
)
|
||||||
max_tokens: Optional[int] = Field(
|
max_tokens: Optional[int] = Field(
|
||||||
default=16, ge=0, description="The maximum number of tokens to generate."
|
default=16, ge=0, description="The maximum number of tokens to generate."
|
||||||
)
|
)
|
||||||
temperature: float = temperature_field
|
temperature: float = temperature_field
|
||||||
|
|
Loading…
Reference in a new issue