Add support for chat completion
This commit is contained in:
parent
3dec778c90
commit
7fedf16531
3 changed files with 211 additions and 4 deletions
|
@ -1,7 +1,18 @@
|
||||||
"""Example FastAPI server for llama.cpp.
|
"""Example FastAPI server for llama.cpp.
|
||||||
|
|
||||||
|
To run this example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install fastapi uvicorn sse-starlette
|
||||||
|
export MODEL=../models/7B/...
|
||||||
|
uvicorn fastapi_server_chat:app --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional, Iterator
|
from typing import List, Optional, Literal, Union, Iterator
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
|
||||||
|
@ -95,4 +106,67 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
||||||
response_model=CreateEmbeddingResponse,
|
response_model=CreateEmbeddingResponse,
|
||||||
)
|
)
|
||||||
def create_embedding(request: CreateEmbeddingRequest):
|
def create_embedding(request: CreateEmbeddingRequest):
|
||||||
return llama.create_embedding(request.input)
|
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
|
||||||
|
|
||||||
|
|
||||||
|
# One message of the conversation submitted to the chat completion route.
class ChatCompletionRequestMessage(BaseModel):
    # Who is speaking; limited to the three literal roles listed here.
    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
    # Text of the message.
    content: str
    # Optional user label attached to the message; None when not supplied.
    user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Request body accepted by the chat completion route.
class CreateChatCompletionRequest(BaseModel):
    # Accepted in the request body; the route excludes it before calling llama.
    model: Optional[str]
    # Conversation so far, oldest message first.
    messages: List[ChatCompletionRequestMessage]
    # Generation settings, forwarded to llama.create_chat_completion as-is.
    temperature: float = 0.8
    top_p: float = 0.95
    stream: bool = False
    stop: List[str] = []
    max_tokens: int = 128
    repeat_penalty: float = 1.1

    class Config:
        # Sample payload attached to the model's schema.
        schema_extra = {
            "example": {
                "messages": [
                    ChatCompletionRequestMessage(
                        role="system", content="You are a helpful assistant."
                    ),
                    ChatCompletionRequestMessage(
                        role="user", content="What is the capital of France?"
                    ),
                ]
            }
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Response model for the chat completion route, built from the
# llama_cpp.ChatCompletion TypedDict.
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
    "/v1/chat/completions",
    response_model=CreateChatCompletionResponse,
)
async def create_chat_completion(
    request: CreateChatCompletionRequest,
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
    """Create a chat completion for the submitted messages.

    Every request field except ``model`` is forwarded to
    ``llama.create_chat_completion``. When ``stream`` is true the result is
    sent as server-sent events, one JSON-encoded chunk per event, followed
    by a final ``[DONE]`` event; otherwise the completion is returned
    directly.
    """
    completion_or_chunks = llama.create_chat_completion(
        # exclude_none drops unset optional fields such as a message's
        # ``user``; otherwise ``user=None`` is forwarded and ends up
        # rendered as the text "None" inside the prompt.
        **request.dict(exclude={"model"}, exclude_none=True),
    )

    if request.stream:

        async def server_sent_events(
            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
        ):
            for chat_chunk in chat_chunks:
                yield dict(data=json.dumps(chat_chunk))
            # Sentinel event emitted once the chunk iterator is exhausted.
            yield dict(data="[DONE]")

        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore

        return EventSourceResponse(
            server_sent_events(chunks),
        )
    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
    return completion
|
||||||
|
|
|
@ -517,6 +517,99 @@ class Llama:
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _convert_text_completion_to_chat(
|
||||||
|
self, completion: Completion
|
||||||
|
) -> ChatCompletion:
|
||||||
|
return {
|
||||||
|
"id": "chat" + completion["id"],
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": completion["created"],
|
||||||
|
"model": completion["model"],
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": completion["choices"][0]["text"],
|
||||||
|
},
|
||||||
|
"finish_reason": completion["choices"][0]["finish_reason"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": completion["usage"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _convert_text_completion_chunks_to_chat(
|
||||||
|
self,
|
||||||
|
chunks: Iterator[CompletionChunk],
|
||||||
|
) -> Iterator[ChatCompletionChunk]:
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
if i == 0:
|
||||||
|
yield {
|
||||||
|
"id": "chat" + chunk["id"],
|
||||||
|
"model": chunk["model"],
|
||||||
|
"created": chunk["created"],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
"id": "chat" + chunk["id"],
|
||||||
|
"model": chunk["model"],
|
||||||
|
"created": chunk["created"],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": chunk["choices"][0]["text"],
|
||||||
|
},
|
||||||
|
"finish_reason": chunk["choices"][0]["finish_reason"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_chat_completion(
    self,
    messages: List[ChatCompletionMessage],
    temperature: float = 0.8,
    top_p: float = 0.95,
    top_k: int = 40,
    stream: bool = False,
    stop: List[str] = [],
    max_tokens: int = 128,
    repeat_penalty: float = 1.1,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
    """Generate the assistant's next message for a chat conversation.

    The messages are flattened into a single text prompt, one
    ``"<role>[ <user>]: <content>"`` line per message, and the prompt is
    completed by calling this object itself. The text completion result is
    then converted to the chat format.

    Args:
        messages: Conversation so far, oldest first.
        temperature: Forwarded unchanged to the text completion call.
        top_p: Forwarded unchanged to the text completion call.
        top_k: Forwarded unchanged to the text completion call.
        stream: When true, an iterator of chat chunks is returned instead
            of a single chat completion.
        stop: Extra stop strings, appended to the built-in ones. The
            default list is never mutated here.
        max_tokens: Forwarded unchanged to the text completion call.
        repeat_penalty: Forwarded unchanged to the text completion call.

    Returns:
        A chat completion, or an iterator of chat completion chunks when
        ``stream`` is true.
    """
    instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""

    def _speaker(message) -> str:
        # ``user`` may be missing or explicitly None (e.g. an unset optional
        # field); both mean "no user". Using the role alone avoids rendering
        # "role None: ..." or "role : ..." in the prompt.
        user = message.get("user")
        return f'{message["role"]} {user}' if user else message["role"]

    chat_history = "\n".join(
        f'{_speaker(message)}: {message["content"]}' for message in messages
    )
    PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
    # Stop as soon as the model starts a new section or a new speaker line.
    PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
    completion_or_chunks = self(
        prompt=PROMPT,
        stop=PROMPT_STOP + stop,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        stream=stream,
        max_tokens=max_tokens,
        repeat_penalty=repeat_penalty,
    )
    if stream:
        chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
        return self._convert_text_completion_chunks_to_chat(chunks)
    else:
        completion: Completion = completion_or_chunks  # type: ignore
        return self._convert_text_completion_to_chat(completion)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.ctx is not None:
|
if self.ctx is not None:
|
||||||
llama_cpp.llama_free(self.ctx)
|
llama_cpp.llama_free(self.ctx)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import List, Optional, Dict, Literal
|
from typing import List, Optional, Dict, Literal, Union
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingUsage(TypedDict):
|
class EmbeddingUsage(TypedDict):
|
||||||
|
@ -55,3 +55,43 @@ class Completion(TypedDict):
|
||||||
model: str
|
model: str
|
||||||
choices: List[CompletionChoice]
|
choices: List[CompletionChoice]
|
||||||
usage: CompletionUsage
|
usage: CompletionUsage
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionMessage(TypedDict):
    """A single chat message: a role, its text, and an optional user label."""

    role: Union[Literal["assistant"], Literal["user"], Literal["system"]]
    content: str
    user: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChoice(TypedDict):
    """One choice of a chat completion: its index, message and finish reason."""

    index: int
    message: ChatCompletionMessage
    finish_reason: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletion(TypedDict):
    """A complete (non-streamed) chat completion response."""

    id: str
    object: Literal["chat.completion"]
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: CompletionUsage
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChunkDelta(TypedDict):
    """Incremental part of a streamed chat message; both keys are optional."""

    role: NotRequired[Literal["assistant"]]
    content: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChunkChoice(TypedDict):
    """One choice inside a streamed chat chunk: index, delta and finish reason."""

    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChunk(TypedDict):
    """A single streamed chunk of a chat completion."""

    id: str
    model: str
    object: Literal["chat.completion.chunk"]
    created: int
    choices: List[ChatCompletionChunkChoice]
|
||||||
|
|
Loading…
Add table
Reference in a new issue