From 7fedf1653109abf2503294a5ca58c4d817f9acd2 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 3 Apr 2023 20:12:44 -0400 Subject: [PATCH] Add support for chat completion --- examples/fastapi_server.py | 78 +++++++++++++++++++++++++++++++- llama_cpp/llama.py | 93 ++++++++++++++++++++++++++++++++++++++ llama_cpp/llama_types.py | 44 +++++++++++++++++- 3 files changed, 211 insertions(+), 4 deletions(-) diff --git a/examples/fastapi_server.py b/examples/fastapi_server.py index 728d3f7..b37f129 100644 --- a/examples/fastapi_server.py +++ b/examples/fastapi_server.py @@ -1,7 +1,18 @@ """Example FastAPI server for llama.cpp. + +To run this example: + +```bash +pip install fastapi uvicorn sse-starlette +export MODEL=../models/7B/... +uvicorn fastapi_server_chat:app --reload +``` + +Then visit http://localhost:8000/docs to see the interactive API docs. + """ import json -from typing import List, Optional, Iterator +from typing import List, Optional, Literal, Union, Iterator import llama_cpp @@ -95,4 +106,67 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding) response_model=CreateEmbeddingResponse, ) def create_embedding(request: CreateEmbeddingRequest): - return llama.create_embedding(request.input) + return llama.create_embedding(**request.dict(exclude={"model", "user"})) + + +class ChatCompletionRequestMessage(BaseModel): + role: Union[Literal["system"], Literal["user"], Literal["assistant"]] + content: str + user: Optional[str] = None + + +class CreateChatCompletionRequest(BaseModel): + model: Optional[str] + messages: List[ChatCompletionRequestMessage] + temperature: float = 0.8 + top_p: float = 0.95 + stream: bool = False + stop: List[str] = [] + max_tokens: int = 128 + repeat_penalty: float = 1.1 + + class Config: + schema_extra = { + "example": { + "messages": [ + ChatCompletionRequestMessage( + role="system", content="You are a helpful assistant." + ), + ChatCompletionRequestMessage( + role="user", content="What is the capital of France?" + ), + ] + } + } + + +CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion) + + +@app.post( + "/v1/chat/completions", + response_model=CreateChatCompletionResponse, +) +async def create_chat_completion( + request: CreateChatCompletionRequest, +) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]: + completion_or_chunks = llama.create_chat_completion( + **request.dict(exclude={"model"}), + ) + + if request.stream: + + async def server_sent_events( + chat_chunks: Iterator[llama_cpp.ChatCompletionChunk], + ): + for chat_chunk in chat_chunks: + yield dict(data=json.dumps(chat_chunk)) + yield dict(data="[DONE]") + + chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore + + return EventSourceResponse( + server_sent_events(chunks), + ) + completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore + return completion diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 332cef9..be98bee 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -517,6 +517,99 @@ class Llama: stream=stream, ) + def _convert_text_completion_to_chat( + self, completion: Completion + ) -> ChatCompletion: + return { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": completion["choices"][0]["text"], + }, + "finish_reason": completion["choices"][0]["finish_reason"], + } + ], + "usage": completion["usage"], + } + + def _convert_text_completion_chunks_to_chat( + self, + chunks: Iterator[CompletionChunk], + ) -> Iterator[ChatCompletionChunk]: + for i, chunk in enumerate(chunks): + if i == 0: + yield { + "id": "chat" + chunk["id"], + "model": chunk["model"], + "created": chunk["created"], + "object": "chat.completion.chunk", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + }, + "finish_reason": None, + } + ], + } + yield { + "id": "chat" + chunk["id"], + "model": chunk["model"], + "created": chunk["created"], + "object": "chat.completion.chunk", + "choices": [ + { + "index": 0, + "delta": { + "content": chunk["choices"][0]["text"], + }, + "finish_reason": chunk["choices"][0]["finish_reason"], + } + ], + } + + def create_chat_completion( + self, + messages: List[ChatCompletionMessage], + temperature: float = 0.8, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: List[str] = [], + max_tokens: int = 128, + repeat_penalty: float = 1.1, + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: + instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions.""" + chat_history = "\n".join( + f'{message["role"]} {message.get("user", "")}: {message["content"]}' + for message in messages + ) + PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: " + PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "] + completion_or_chunks = self( + prompt=PROMPT, + stop=PROMPT_STOP + stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + max_tokens=max_tokens, + repeat_penalty=repeat_penalty, + ) + if stream: + chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore + return self._convert_text_completion_chunks_to_chat(chunks) + else: + completion: Completion = completion_or_chunks # type: ignore + return self._convert_text_completion_to_chat(completion) + def __del__(self): if self.ctx is not None: llama_cpp.llama_free(self.ctx) diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index d8c0b83..3e9e803 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -1,5 +1,5 @@ -from typing import List, Optional, Dict, Literal -from typing_extensions import TypedDict +from typing import List, Optional, Dict, Literal, Union +from typing_extensions import TypedDict, NotRequired class EmbeddingUsage(TypedDict): @@ -55,3 +55,43 @@ class Completion(TypedDict): model: str choices: List[CompletionChoice] usage: CompletionUsage + + +class ChatCompletionMessage(TypedDict): + role: Union[Literal["assistant"], Literal["user"], Literal["system"]] + content: str + user: NotRequired[str] + + +class ChatCompletionChoice(TypedDict): + index: int + message: ChatCompletionMessage + finish_reason: Optional[str] + + +class ChatCompletion(TypedDict): + id: str + object: Literal["chat.completion"] + created: int + model: str + choices: List[ChatCompletionChoice] + usage: CompletionUsage + + +class ChatCompletionChunkDelta(TypedDict): + role: NotRequired[Literal["assistant"]] + content: NotRequired[str] + + +class ChatCompletionChunkChoice(TypedDict): + index: int + delta: ChatCompletionChunkDelta + finish_reason: Optional[str] + + +class ChatCompletionChunk(TypedDict): + id: str + model: str + object: Literal["chat.completion.chunk"] + created: int + choices: List[ChatCompletionChunkChoice]