Add support for chat completion

This commit is contained in:
Andrei Betlen 2023-04-03 20:12:44 -04:00
parent 3dec778c90
commit 7fedf16531
3 changed files with 211 additions and 4 deletions

View file

@ -1,7 +1,18 @@
"""Example FastAPI server for llama.cpp.
To run this example:
```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
uvicorn fastapi_server_chat:app --reload
```
Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import json
from typing import List, Optional, Iterator
from typing import List, Optional, Literal, Union, Iterator
import llama_cpp
@ -95,4 +106,67 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
response_model=CreateEmbeddingResponse,
)
def create_embedding(request: CreateEmbeddingRequest):
return llama.create_embedding(request.input)
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
class ChatCompletionRequestMessage(BaseModel):
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
content: str
user: Optional[str] = None
class CreateChatCompletionRequest(BaseModel):
model: Optional[str]
messages: List[ChatCompletionRequestMessage]
temperature: float = 0.8
top_p: float = 0.95
stream: bool = False
stop: List[str] = []
max_tokens: int = 128
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
),
]
}
}
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
async def create_chat_completion(
request: CreateChatCompletionRequest,
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(exclude={"model"}),
)
if request.stream:
async def server_sent_events(
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
):
for chat_chunk in chat_chunks:
yield dict(data=json.dumps(chat_chunk))
yield dict(data="[DONE]")
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(
server_sent_events(chunks),
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion

View file

@ -517,6 +517,99 @@ class Llama:
stream=stream,
)
def _convert_text_completion_to_chat(
self, completion: Completion
) -> ChatCompletion:
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
"created": completion["created"],
"model": completion["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": completion["choices"][0]["text"],
},
"finish_reason": completion["choices"][0]["finish_reason"],
}
],
"usage": completion["usage"],
}
def _convert_text_completion_chunks_to_chat(
self,
chunks: Iterator[CompletionChunk],
) -> Iterator[ChatCompletionChunk]:
for i, chunk in enumerate(chunks):
if i == 0:
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
},
"finish_reason": None,
}
],
}
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"content": chunk["choices"][0]["text"],
},
"finish_reason": chunk["choices"][0]["finish_reason"],
}
],
}
def create_chat_completion(
self,
messages: List[ChatCompletionMessage],
temperature: float = 0.8,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: List[str] = [],
max_tokens: int = 128,
repeat_penalty: float = 1.1,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
chat_history = "\n".join(
f'{message["role"]} {message.get("user", "")}: {message["content"]}'
for message in messages
)
PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
completion_or_chunks = self(
prompt=PROMPT,
stop=PROMPT_STOP + stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream=stream,
max_tokens=max_tokens,
repeat_penalty=repeat_penalty,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
return self._convert_text_completion_chunks_to_chat(chunks)
else:
completion: Completion = completion_or_chunks # type: ignore
return self._convert_text_completion_to_chat(completion)
def __del__(self):
if self.ctx is not None:
llama_cpp.llama_free(self.ctx)

View file

@ -1,5 +1,5 @@
from typing import List, Optional, Dict, Literal
from typing_extensions import TypedDict
from typing import List, Optional, Dict, Literal, Union
from typing_extensions import TypedDict, NotRequired
class EmbeddingUsage(TypedDict):
@ -55,3 +55,43 @@ class Completion(TypedDict):
model: str
choices: List[CompletionChoice]
usage: CompletionUsage
class ChatCompletionMessage(TypedDict):
role: Union[Literal["assistant"], Literal["user"], Literal["system"]]
content: str
user: NotRequired[str]
class ChatCompletionChoice(TypedDict):
index: int
message: ChatCompletionMessage
finish_reason: Optional[str]
class ChatCompletion(TypedDict):
id: str
object: Literal["chat.completion"]
created: int
model: str
choices: List[ChatCompletionChoice]
usage: CompletionUsage
class ChatCompletionChunkDelta(TypedDict):
role: NotRequired[Literal["assistant"]]
content: NotRequired[str]
class ChatCompletionChunkChoice(TypedDict):
index: int
delta: ChatCompletionChunkDelta
finish_reason: Optional[str]
class ChatCompletionChunk(TypedDict):
id: str
model: str
object: Literal["chat.completion.chunk"]
created: int
choices: List[ChatCompletionChunkChoice]