Add support for chat completion

parent 3dec778c90
commit 7fedf16531

3 changed files with 211 additions and 4 deletions

@@ -1,7 +1,18 @@
 """Example FastAPI server for llama.cpp.
+
+To run this example:
+
+```bash
+pip install fastapi uvicorn sse-starlette
+export MODEL=../models/7B/...
+uvicorn fastapi_server_chat:app --reload
+```
+
+Then visit http://localhost:8000/docs to see the interactive API docs.
+
 """
 import json
-from typing import List, Optional, Iterator
+from typing import List, Optional, Literal, Union, Iterator
 
 import llama_cpp
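
A quick way to sanity-check the example server before exercising the new route: fetch the generated OpenAPI schema and confirm the chat endpoint is registered. This is an illustrative smoke test, not part of the commit, and it assumes the `uvicorn` command above is already running on the default port.

```python
# Illustrative smoke test (not part of this commit): verify the example
# server is up and that the new chat route is registered.
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/openapi.json") as resp:
    schema = json.load(resp)

print("/v1/chat/completions" in schema["paths"])  # expect: True
```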
@@ -95,4 +106,67 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     response_model=CreateEmbeddingResponse,
 )
 def create_embedding(request: CreateEmbeddingRequest):
-    return llama.create_embedding(request.input)
+    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+
+
+class ChatCompletionRequestMessage(BaseModel):
+    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
+    content: str
+    user: Optional[str] = None
+
+
+class CreateChatCompletionRequest(BaseModel):
+    model: Optional[str]
+    messages: List[ChatCompletionRequestMessage]
+    temperature: float = 0.8
+    top_p: float = 0.95
+    stream: bool = False
+    stop: List[str] = []
+    max_tokens: int = 128
+    repeat_penalty: float = 1.1
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "messages": [
+                    ChatCompletionRequestMessage(
+                        role="system", content="You are a helpful assistant."
+                    ),
+                    ChatCompletionRequestMessage(
+                        role="user", content="What is the capital of France?"
+                    ),
+                ]
+            }
+        }
+
+
+CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
+
+
+@app.post(
+    "/v1/chat/completions",
+    response_model=CreateChatCompletionResponse,
+)
+async def create_chat_completion(
+    request: CreateChatCompletionRequest,
+) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
+    completion_or_chunks = llama.create_chat_completion(
+        **request.dict(exclude={"model"}),
+    )
+
+    if request.stream:
+
+        async def server_sent_events(
+            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
+        ):
+            for chat_chunk in chat_chunks:
+                yield dict(data=json.dumps(chat_chunk))
+            yield dict(data="[DONE]")
+
+        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+
+        return EventSourceResponse(
+            server_sent_events(chunks),
+        )
+    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
+    return completion
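
For anyone trying this endpoint by hand, here is a minimal client sketch, not part of the diff, for the streaming path above. It assumes the server is running on localhost:8000, uses only the standard library, and parses the SSE stream naively: one `data:` line per event, ending at the `[DONE]` sentinel that `server_sent_events` emits.

```python
# Illustrative streaming client for POST /v1/chat/completions (assumes the
# example server above is running locally).
import json
import urllib.request

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "stream": True,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue  # skip blank separator lines between events
        data = line[len("data: "):]
        if data == "[DONE]":  # server-side sentinel from server_sent_events
            break
        chunk = json.loads(data)
        # The first chunk carries only the role delta; later ones carry text.
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```

With `"stream": False` the same POST returns a single JSON `ChatCompletion` body instead of an event stream.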
@@ -517,6 +517,99 @@ class Llama:
             stream=stream,
         )
 
+    def _convert_text_completion_to_chat(
+        self, completion: Completion
+    ) -> ChatCompletion:
+        return {
+            "id": "chat" + completion["id"],
+            "object": "chat.completion",
+            "created": completion["created"],
+            "model": completion["model"],
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": completion["choices"][0]["text"],
+                    },
+                    "finish_reason": completion["choices"][0]["finish_reason"],
+                }
+            ],
+            "usage": completion["usage"],
+        }
+
+    def _convert_text_completion_chunks_to_chat(
+        self,
+        chunks: Iterator[CompletionChunk],
+    ) -> Iterator[ChatCompletionChunk]:
+        for i, chunk in enumerate(chunks):
+            if i == 0:
+                yield {
+                    "id": "chat" + chunk["id"],
+                    "model": chunk["model"],
+                    "created": chunk["created"],
+                    "object": "chat.completion.chunk",
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "role": "assistant",
+                            },
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+            yield {
+                "id": "chat" + chunk["id"],
+                "model": chunk["model"],
+                "created": chunk["created"],
+                "object": "chat.completion.chunk",
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {
+                            "content": chunk["choices"][0]["text"],
+                        },
+                        "finish_reason": chunk["choices"][0]["finish_reason"],
+                    }
+                ],
+            }
+
+    def create_chat_completion(
+        self,
+        messages: List[ChatCompletionMessage],
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        stream: bool = False,
+        stop: List[str] = [],
+        max_tokens: int = 128,
+        repeat_penalty: float = 1.1,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
+        chat_history = "\n".join(
+            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
+            for message in messages
+        )
+        PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
+        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
+        completion_or_chunks = self(
+            prompt=PROMPT,
+            stop=PROMPT_STOP + stop,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            stream=stream,
+            max_tokens=max_tokens,
+            repeat_penalty=repeat_penalty,
+        )
+        if stream:
+            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_chunks_to_chat(chunks)
+        else:
+            completion: Completion = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_to_chat(completion)
+
     def __del__(self):
         if self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
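
A usage sketch for the new method, not taken from the diff: the model path is a placeholder, and it assumes `Llama` is importable from the package root.

```python
# Illustrative only: calling the new high-level method directly.
# The model path is a placeholder; point it at a real GGML model file.
from llama_cpp import Llama

llama = Llama(model_path="./models/7B/ggml-model.bin")
result = llama.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    max_tokens=64,
)
print(result["choices"][0]["message"]["content"])
```

Worth noting for review: chat is emulated on top of plain text completion here, by flattening the messages into a single `### Instructions` / `### Inputs` / `### Response` prompt and stopping generation at role markers like `"\nuser: "`, so role fidelity depends entirely on the underlying model.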
@@ -1,5 +1,5 @@
-from typing import List, Optional, Dict, Literal
-from typing_extensions import TypedDict
+from typing import List, Optional, Dict, Literal, Union
+from typing_extensions import TypedDict, NotRequired
 
 
 class EmbeddingUsage(TypedDict):
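
The import change above brings in `NotRequired`, which marks individual TypedDict keys as omittable while the rest stay required. A small illustration, using a hypothetical stand-in type rather than anything from this file:

```python
# Illustrative only: NotRequired makes a single key optional per-key,
# which total=False cannot express without relaxing the whole TypedDict.
from typing_extensions import TypedDict, NotRequired

class Message(TypedDict):  # hypothetical stand-in for ChatCompletionMessage
    role: str
    content: str
    user: NotRequired[str]

ok: Message = {"role": "user", "content": "hi"}  # "user" may be omitted
```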
@@ -55,3 +55,43 @@ class Completion(TypedDict):
     model: str
     choices: List[CompletionChoice]
     usage: CompletionUsage
+
+
+class ChatCompletionMessage(TypedDict):
+    role: Union[Literal["assistant"], Literal["user"], Literal["system"]]
+    content: str
+    user: NotRequired[str]
+
+
+class ChatCompletionChoice(TypedDict):
+    index: int
+    message: ChatCompletionMessage
+    finish_reason: Optional[str]
+
+
+class ChatCompletion(TypedDict):
+    id: str
+    object: Literal["chat.completion"]
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+    usage: CompletionUsage
+
+
+class ChatCompletionChunkDelta(TypedDict):
+    role: NotRequired[Literal["assistant"]]
+    content: NotRequired[str]
+
+
+class ChatCompletionChunkChoice(TypedDict):
+    index: int
+    delta: ChatCompletionChunkDelta
+    finish_reason: Optional[str]
+
+
+class ChatCompletionChunk(TypedDict):
+    id: str
+    model: str
+    object: Literal["chat.completion.chunk"]
+    created: int
+    choices: List[ChatCompletionChunkChoice]
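
As a final sanity check on the chunk shape, a small helper sketch, not in the diff, that folds a stream of `ChatCompletionChunk`s back into the assistant's full message. It assumes these types live in `llama_cpp.llama_types`:

```python
# Illustrative helper: accumulate streamed deltas into one string.
# Assumption: the TypedDicts above are importable from llama_cpp.llama_types.
from typing import Iterator

from llama_cpp.llama_types import ChatCompletionChunk


def collect_content(chunks: Iterator[ChatCompletionChunk]) -> str:
    parts = []
    for chunk in chunks:
        delta = chunk["choices"][0]["delta"]
        if "content" in delta:  # the first chunk carries only the role delta
            parts.append(delta["content"])
    return "".join(parts)
```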