Merge branch 'main' of https://github.com/abetlen/llama-cpp-python
This commit is contained in:
commit
e4c6f34d95
19 changed files with 6212 additions and 123 deletions
30
.github/workflows/test.yaml
vendored
Normal file
30
.github/workflows/test.yaml
vendored
Normal file
|
@ -0,0 +1,30 @@
|
|||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: "true"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip pytest cmake scikit-build
|
||||
python3 setup.py develop
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest
|
|
@ -1,6 +1,7 @@
|
|||
# 🦙 Python Bindings for `llama.cpp`
|
||||
|
||||
[![Documentation](https://img.shields.io/badge/docs-passing-green.svg)](https://abetlen.github.io/llama-cpp-python)
|
||||
[![Tests](https://github.com/abetlen/llama-cpp-python/actions/workflows/test.yaml/badge.svg?branch=main)](https://github.com/abetlen/llama-cpp-python/actions/workflows/test.yaml)
|
||||
[![PyPI](https://img.shields.io/pypi/v/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
|
||||
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
|
||||
[![PyPI - License](https://img.shields.io/pypi/l/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
|
||||
|
@ -70,7 +71,7 @@ python3 setup.py develop
|
|||
|
||||
# How does this compare to other Python bindings of `llama.cpp`?
|
||||
|
||||
I wrote this package for my own use, I had two goals in mind:
|
||||
I originally wrote this package for my own use with two goals in mind:
|
||||
|
||||
- Provide a simple process to install `llama.cpp` and access the full C API in `llama.h` from Python
|
||||
- Provide a high-level Python API that can be used as a drop-in replacement for the OpenAI API so existing apps can be easily ported to use `llama.cpp`
|
||||
|
|
|
@ -71,8 +71,10 @@ python3 setup.py develop
|
|||
- sample
|
||||
- generate
|
||||
- create_embedding
|
||||
- embed
|
||||
- create_completion
|
||||
- __call__
|
||||
- create_chat_completion
|
||||
- token_bos
|
||||
- token_eos
|
||||
show_root_heading: true
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
"""Example FastAPI server for llama.cpp.
|
||||
"""
|
||||
import json
|
||||
from typing import List, Optional, Iterator
|
||||
|
||||
import llama_cpp
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model: str
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="🦙 llama.cpp Python API",
|
||||
version="0.0.1",
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
settings = Settings()
|
||||
llama = llama_cpp.Llama(
|
||||
settings.model,
|
||||
f16_kv=True,
|
||||
use_mlock=True,
|
||||
embedding=True,
|
||||
n_threads=6,
|
||||
n_batch=2048,
|
||||
)
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: str
|
||||
suffix: Optional[str] = Field(None)
|
||||
max_tokens: int = 16
|
||||
temperature: float = 0.8
|
||||
top_p: float = 0.95
|
||||
logprobs: Optional[int] = Field(None)
|
||||
echo: bool = False
|
||||
stop: List[str] = []
|
||||
repeat_penalty: float = 1.1
|
||||
top_k: int = 40
|
||||
stream: bool = False
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
|
||||
"stop": ["\n", "###"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/v1/completions",
|
||||
response_model=CreateCompletionResponse,
|
||||
)
|
||||
def create_completion(request: CreateCompletionRequest):
|
||||
if request.stream:
|
||||
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
|
||||
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
|
||||
return llama(**request.dict())
|
||||
|
||||
|
||||
class CreateEmbeddingRequest(BaseModel):
|
||||
model: Optional[str]
|
||||
input: str
|
||||
user: Optional[str]
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"input": "The food was delicious and the waiter...",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/v1/embeddings",
|
||||
response_model=CreateEmbeddingResponse,
|
||||
)
|
||||
def create_embedding(request: CreateEmbeddingRequest):
|
||||
return llama.create_embedding(request.input)
|
181
examples/high_level_api/fastapi_server.py
Normal file
181
examples/high_level_api/fastapi_server.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
"""Example FastAPI server for llama.cpp.
|
||||
|
||||
To run this example:
|
||||
|
||||
```bash
|
||||
pip install fastapi uvicorn sse-starlette
|
||||
export MODEL=../models/7B/...
|
||||
uvicorn fastapi_server_chat:app --reload
|
||||
```
|
||||
|
||||
Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Literal, Union, Iterator
|
||||
|
||||
import llama_cpp
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model: str
|
||||
n_ctx: int = 2048
|
||||
n_batch: int = 2048
|
||||
n_threads: int = os.cpu_count() or 1
|
||||
f16_kv: bool = True
|
||||
use_mlock: bool = True
|
||||
embedding: bool = True
|
||||
last_n_tokens_size: int = 64
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="🦙 llama.cpp Python API",
|
||||
version="0.0.1",
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
settings = Settings()
|
||||
llama = llama_cpp.Llama(
|
||||
settings.model,
|
||||
f16_kv=settings.f16_kv,
|
||||
use_mlock=settings.use_mlock,
|
||||
embedding=settings.embedding,
|
||||
n_threads=settings.n_threads,
|
||||
n_batch=settings.n_batch,
|
||||
n_ctx=settings.n_ctx,
|
||||
last_n_tokens_size=settings.last_n_tokens_size,
|
||||
)
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: str
|
||||
suffix: Optional[str] = Field(None)
|
||||
max_tokens: int = 16
|
||||
temperature: float = 0.8
|
||||
top_p: float = 0.95
|
||||
logprobs: Optional[int] = Field(None)
|
||||
echo: bool = False
|
||||
stop: List[str] = []
|
||||
repeat_penalty: float = 1.1
|
||||
top_k: int = 40
|
||||
stream: bool = False
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
|
||||
"stop": ["\n", "###"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/v1/completions",
|
||||
response_model=CreateCompletionResponse,
|
||||
)
|
||||
def create_completion(request: CreateCompletionRequest):
|
||||
if request.stream:
|
||||
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
|
||||
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
|
||||
return llama(**request.dict())
|
||||
|
||||
|
||||
class CreateEmbeddingRequest(BaseModel):
|
||||
model: Optional[str]
|
||||
input: str
|
||||
user: Optional[str]
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"input": "The food was delicious and the waiter...",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/v1/embeddings",
|
||||
response_model=CreateEmbeddingResponse,
|
||||
)
|
||||
def create_embedding(request: CreateEmbeddingRequest):
|
||||
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
|
||||
|
||||
|
||||
class ChatCompletionRequestMessage(BaseModel):
|
||||
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
|
||||
content: str
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class CreateChatCompletionRequest(BaseModel):
|
||||
model: Optional[str]
|
||||
messages: List[ChatCompletionRequestMessage]
|
||||
temperature: float = 0.8
|
||||
top_p: float = 0.95
|
||||
stream: bool = False
|
||||
stop: List[str] = []
|
||||
max_tokens: int = 128
|
||||
repeat_penalty: float = 1.1
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
ChatCompletionRequestMessage(
|
||||
role="system", content="You are a helpful assistant."
|
||||
),
|
||||
ChatCompletionRequestMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=CreateChatCompletionResponse,
|
||||
)
|
||||
async def create_chat_completion(
|
||||
request: CreateChatCompletionRequest,
|
||||
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
|
||||
completion_or_chunks = llama.create_chat_completion(
|
||||
**request.dict(exclude={"model"}),
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
|
||||
async def server_sent_events(
|
||||
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
|
||||
):
|
||||
for chat_chunk in chat_chunks:
|
||||
yield dict(data=json.dumps(chat_chunk))
|
||||
yield dict(data="[DONE]")
|
||||
|
||||
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
|
||||
|
||||
return EventSourceResponse(
|
||||
server_sent_events(chunks),
|
||||
)
|
||||
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
|
||||
return completion
|
|
@ -11,7 +11,7 @@ llm = Llama(model_path=args.model)
|
|||
|
||||
output = llm(
|
||||
"Question: What are the names of the planets in the solar system? Answer: ",
|
||||
max_tokens=1,
|
||||
max_tokens=48,
|
||||
stop=["Q:", "\n"],
|
||||
echo=True,
|
||||
)
|
|
@ -4,7 +4,7 @@ import argparse
|
|||
from llama_cpp import Llama
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default=".//models/...")
|
||||
parser.add_argument("-m", "--model", type=str, default="./models/...")
|
||||
args = parser.parse_args()
|
||||
|
||||
llm = Llama(model_path=args.model)
|
25
examples/low_level_api/quantize.py
Normal file
25
examples/low_level_api/quantize.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import os
|
||||
import argparse
|
||||
import llama_cpp
|
||||
|
||||
|
||||
def main(args):
|
||||
if not os.path.exists(fname_inp):
|
||||
raise RuntimeError(f"Input file does not exist ({fname_inp})")
|
||||
if os.path.exists(fname_out):
|
||||
raise RuntimeError(f"Output file already exists ({fname_out})")
|
||||
fname_inp = args.fname_inp.encode("utf-8")
|
||||
fname_out = args.fname_out.encode("utf-8")
|
||||
itype = args.itype
|
||||
return_code = llama_cpp.llama_model_quantize(fname_inp, fname_out, itype)
|
||||
if return_code != 0:
|
||||
raise RuntimeError("Failed to quantize model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("fname_inp", type=str, help="Path to input model")
|
||||
parser.add_argument("fname_out", type=str, help="Path to output model")
|
||||
parser.add_argument("type", type=int, help="Type of quantization (2: q4_0, 3: q4_1)")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
5540
examples/notebooks/PerformanceTuning.ipynb
Normal file
5540
examples/notebooks/PerformanceTuning.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -1,8 +1,9 @@
|
|||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import multiprocessing
|
||||
from typing import List, Optional, Union, Generator, Sequence
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
||||
from collections import deque
|
||||
|
||||
from . import llama_cpp
|
||||
|
@ -27,6 +28,7 @@ class Llama:
|
|||
n_threads: Optional[int] = None,
|
||||
n_batch: int = 8,
|
||||
last_n_tokens_size: int = 64,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""Load a llama.cpp model from `model_path`.
|
||||
|
||||
|
@ -43,6 +45,7 @@ class Llama:
|
|||
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
|
||||
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
|
||||
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
|
||||
verbose: Print verbose output to stderr.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model path does not exist.
|
||||
|
@ -50,6 +53,7 @@ class Llama:
|
|||
Returns:
|
||||
A Llama instance.
|
||||
"""
|
||||
self.verbose = verbose
|
||||
self.model_path = model_path
|
||||
|
||||
self.params = llama_cpp.llama_context_default_params()
|
||||
|
@ -68,7 +72,7 @@ class Llama:
|
|||
maxlen=self.last_n_tokens_size,
|
||||
)
|
||||
self.tokens_consumed = 0
|
||||
self.n_batch = n_batch
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
|
||||
self.n_threads = n_threads or multiprocessing.cpu_count()
|
||||
|
||||
|
@ -79,6 +83,9 @@ class Llama:
|
|||
self.model_path.encode("utf-8"), self.params
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||
|
||||
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
|
||||
"""Tokenize a string.
|
||||
|
||||
|
@ -169,11 +176,6 @@ class Llama:
|
|||
The sampled token.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
# Temporary workaround for https://github.com/ggerganov/llama.cpp/issues/684
|
||||
if temp == 0.0:
|
||||
temp = 1.0
|
||||
top_p = 0.0
|
||||
top_k = 1
|
||||
return llama_cpp.llama_sample_top_p_top_k(
|
||||
ctx=self.ctx,
|
||||
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
||||
|
@ -239,6 +241,15 @@ class Llama:
|
|||
An embedding object.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
|
||||
if self.params.embedding == False:
|
||||
raise RuntimeError(
|
||||
"Llama model must be created with embedding=True to call this method"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_reset_timings(self.ctx)
|
||||
|
||||
tokens = self.tokenize(input.encode("utf-8"))
|
||||
self.reset()
|
||||
self.eval(tokens)
|
||||
|
@ -246,6 +257,10 @@ class Llama:
|
|||
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
|
||||
: llama_cpp.llama_n_embd(self.ctx)
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self.ctx)
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
|
@ -262,6 +277,17 @@ class Llama:
|
|||
},
|
||||
}
|
||||
|
||||
def embed(self, input: str) -> List[float]:
|
||||
"""Embed a string.
|
||||
|
||||
Args:
|
||||
input: The utf-8 encoded string to embed.
|
||||
|
||||
Returns:
|
||||
A list of embeddings
|
||||
"""
|
||||
return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
|
||||
|
||||
def _create_completion(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -275,10 +301,7 @@ class Llama:
|
|||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
) -> Union[
|
||||
Generator[Completion, None, None],
|
||||
Generator[CompletionChunk, None, None],
|
||||
]:
|
||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
|
||||
assert self.ctx is not None
|
||||
completion_id = f"cmpl-{str(uuid.uuid4())}"
|
||||
created = int(time.time())
|
||||
|
@ -288,6 +311,9 @@ class Llama:
|
|||
text = b""
|
||||
returned_characters = 0
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_reset_timings(self.ctx)
|
||||
|
||||
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
|
||||
raise ValueError(
|
||||
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
||||
|
@ -384,6 +410,9 @@ class Llama:
|
|||
if logprobs is not None:
|
||||
raise NotImplementedError("logprobs not implemented")
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self.ctx)
|
||||
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -417,7 +446,7 @@ class Llama:
|
|||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
) -> Union[Completion, Generator[CompletionChunk, None, None]]:
|
||||
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||
"""Generate text from a prompt.
|
||||
|
||||
Args:
|
||||
|
@ -454,7 +483,7 @@ class Llama:
|
|||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
chunks: Generator[CompletionChunk, None, None] = completion_or_chunks
|
||||
chunks: Iterator[CompletionChunk] = completion_or_chunks
|
||||
return chunks
|
||||
completion: Completion = next(completion_or_chunks) # type: ignore
|
||||
return completion
|
||||
|
@ -472,7 +501,7 @@ class Llama:
|
|||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
):
|
||||
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||
"""Generate text from a prompt.
|
||||
|
||||
Args:
|
||||
|
@ -509,11 +538,158 @@ class Llama:
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
def _convert_text_completion_to_chat(
|
||||
self, completion: Completion
|
||||
) -> ChatCompletion:
|
||||
return {
|
||||
"id": "chat" + completion["id"],
|
||||
"object": "chat.completion",
|
||||
"created": completion["created"],
|
||||
"model": completion["model"],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": completion["choices"][0]["text"],
|
||||
},
|
||||
"finish_reason": completion["choices"][0]["finish_reason"],
|
||||
}
|
||||
],
|
||||
"usage": completion["usage"],
|
||||
}
|
||||
|
||||
def _convert_text_completion_chunks_to_chat(
|
||||
self,
|
||||
chunks: Iterator[CompletionChunk],
|
||||
) -> Iterator[ChatCompletionChunk]:
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
yield {
|
||||
"id": "chat" + chunk["id"],
|
||||
"model": chunk["model"],
|
||||
"created": chunk["created"],
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
yield {
|
||||
"id": "chat" + chunk["id"],
|
||||
"model": chunk["model"],
|
||||
"created": chunk["created"],
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": chunk["choices"][0]["text"],
|
||||
},
|
||||
"finish_reason": chunk["choices"][0]["finish_reason"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: List[ChatCompletionMessage],
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
stop: List[str] = [],
|
||||
max_tokens: int = 128,
|
||||
repeat_penalty: float = 1.1,
|
||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
||||
"""Generate a chat completion from a list of messages.
|
||||
|
||||
Args:
|
||||
messages: A list of messages to generate a response for.
|
||||
temperature: The temperature to use for sampling.
|
||||
top_p: The top-p value to use for sampling.
|
||||
top_k: The top-k value to use for sampling.
|
||||
stream: Whether to stream the results.
|
||||
stop: A list of strings to stop generation when encountered.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
repeat_penalty: The penalty to apply to repeated tokens.
|
||||
|
||||
Returns:
|
||||
Generated chat completion or a stream of chat completion chunks.
|
||||
"""
|
||||
instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
|
||||
chat_history = "\n".join(
|
||||
f'{message["role"]} {message.get("user", "")}: {message["content"]}'
|
||||
for message in messages
|
||||
)
|
||||
PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
|
||||
PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
|
||||
completion_or_chunks = self(
|
||||
prompt=PROMPT,
|
||||
stop=PROMPT_STOP + stop,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream=stream,
|
||||
max_tokens=max_tokens,
|
||||
repeat_penalty=repeat_penalty,
|
||||
)
|
||||
if stream:
|
||||
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
||||
return self._convert_text_completion_chunks_to_chat(chunks)
|
||||
else:
|
||||
completion: Completion = completion_or_chunks # type: ignore
|
||||
return self._convert_text_completion_to_chat(completion)
|
||||
|
||||
def __del__(self):
|
||||
if self.ctx is not None:
|
||||
llama_cpp.llama_free(self.ctx)
|
||||
self.ctx = None
|
||||
|
||||
def __getstate__(self):
|
||||
return dict(
|
||||
verbose=self.verbose,
|
||||
model_path=self.model_path,
|
||||
n_ctx=self.params.n_ctx,
|
||||
n_parts=self.params.n_parts,
|
||||
seed=self.params.seed,
|
||||
f16_kv=self.params.f16_kv,
|
||||
logits_all=self.params.logits_all,
|
||||
vocab_only=self.params.vocab_only,
|
||||
use_mlock=self.params.use_mlock,
|
||||
embedding=self.params.embedding,
|
||||
last_n_tokens_size=self.last_n_tokens_size,
|
||||
last_n_tokens_data=self.last_n_tokens_data,
|
||||
tokens_consumed=self.tokens_consumed,
|
||||
n_batch=self.n_batch,
|
||||
n_threads=self.n_threads,
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__init__(
|
||||
model_path=state["model_path"],
|
||||
n_ctx=state["n_ctx"],
|
||||
n_parts=state["n_parts"],
|
||||
seed=state["seed"],
|
||||
f16_kv=state["f16_kv"],
|
||||
logits_all=state["logits_all"],
|
||||
vocab_only=state["vocab_only"],
|
||||
use_mlock=state["use_mlock"],
|
||||
embedding=state["embedding"],
|
||||
n_threads=state["n_threads"],
|
||||
n_batch=state["n_batch"],
|
||||
last_n_tokens_size=state["last_n_tokens_size"],
|
||||
verbose=state["verbose"],
|
||||
)
|
||||
self.last_n_tokens_data=state["last_n_tokens_data"]
|
||||
self.tokens_consumed=state["tokens_consumed"]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def token_eos() -> llama_cpp.llama_token:
|
||||
"""Return the end-of-sequence token."""
|
||||
|
|
|
@ -125,12 +125,12 @@ _lib.llama_free.restype = None
|
|||
# TODO: not great API - very likely to change
|
||||
# Returns 0 on success
|
||||
def llama_model_quantize(
|
||||
fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
|
||||
fname_inp: bytes, fname_out: bytes, itype: c_int
|
||||
) -> c_int:
|
||||
return _lib.llama_model_quantize(fname_inp, fname_out, itype, qk)
|
||||
return _lib.llama_model_quantize(fname_inp, fname_out, itype)
|
||||
|
||||
|
||||
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
|
||||
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
|
||||
_lib.llama_model_quantize.restype = c_int
|
||||
|
||||
# Returns the KV cache that will contain the context for the
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import List, Optional, Dict, Literal
|
||||
from typing_extensions import TypedDict
|
||||
from typing import List, Optional, Dict, Union
|
||||
from typing_extensions import TypedDict, NotRequired, Literal
|
||||
|
||||
|
||||
class EmbeddingUsage(TypedDict):
|
||||
|
@ -55,3 +55,43 @@ class Completion(TypedDict):
|
|||
model: str
|
||||
choices: List[CompletionChoice]
|
||||
usage: CompletionUsage
|
||||
|
||||
|
||||
class ChatCompletionMessage(TypedDict):
|
||||
role: Union[Literal["assistant"], Literal["user"], Literal["system"]]
|
||||
content: str
|
||||
user: NotRequired[str]
|
||||
|
||||
|
||||
class ChatCompletionChoice(TypedDict):
|
||||
index: int
|
||||
message: ChatCompletionMessage
|
||||
finish_reason: Optional[str]
|
||||
|
||||
|
||||
class ChatCompletion(TypedDict):
|
||||
id: str
|
||||
object: Literal["chat.completion"]
|
||||
created: int
|
||||
model: str
|
||||
choices: List[ChatCompletionChoice]
|
||||
usage: CompletionUsage
|
||||
|
||||
|
||||
class ChatCompletionChunkDelta(TypedDict):
|
||||
role: NotRequired[Literal["assistant"]]
|
||||
content: NotRequired[str]
|
||||
|
||||
|
||||
class ChatCompletionChunkChoice(TypedDict):
|
||||
index: int
|
||||
delta: ChatCompletionChunkDelta
|
||||
finish_reason: Optional[str]
|
||||
|
||||
|
||||
class ChatCompletionChunk(TypedDict):
|
||||
id: str
|
||||
model: str
|
||||
object: Literal["chat.completion.chunk"]
|
||||
created: int
|
||||
choices: List[ChatCompletionChunkChoice]
|
||||
|
|
88
poetry.lock
generated
88
poetry.lock
generated
|
@ -1,5 +1,24 @@
|
|||
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
version = "22.2.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
|
||||
{file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"]
|
||||
dev = ["attrs[docs,tests]"]
|
||||
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"]
|
||||
tests = ["attrs[tests-no-zope]", "zope.interface"]
|
||||
tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "23.1.0"
|
||||
|
@ -328,6 +347,21 @@ files = [
|
|||
{file = "docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.1.1"
|
||||
description = "Backport of PEP 654 (exception groups)"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"},
|
||||
{file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "ghp-import"
|
||||
version = "2.1.0"
|
||||
|
@ -415,6 +449,18 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
|
|||
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.0.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
|
||||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jaraco-classes"
|
||||
version = "3.2.3"
|
||||
|
@ -821,6 +867,22 @@ files = [
|
|||
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
|
||||
test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.0.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
|
||||
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.21"
|
||||
|
@ -864,6 +926,30 @@ files = [
|
|||
markdown = ">=3.2"
|
||||
pyyaml = "*"
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.2.2"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"},
|
||||
{file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=19.2.0"
|
||||
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=0.12,<2.0"
|
||||
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
|
@ -1281,4 +1367,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8.1"
|
||||
content-hash = "cffaf5e2e66ade4f429d0e938277d4fa2c4878ca7338c3c4f91721a7d3aff91b"
|
||||
content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "llama_cpp"
|
||||
version = "0.1.17"
|
||||
version = "0.1.22"
|
||||
description = "Python bindings for the llama.cpp library"
|
||||
authors = ["Andrei Betlen <abetlen@gmail.com>"]
|
||||
license = "MIT"
|
||||
|
@ -23,6 +23,7 @@ twine = "^4.0.2"
|
|||
mkdocs = "^1.4.2"
|
||||
mkdocstrings = {extras = ["python"], version = "^0.20.0"}
|
||||
mkdocs-material = "^9.1.4"
|
||||
pytest = "^7.2.2"
|
||||
|
||||
[build-system]
|
||||
requires = [
|
||||
|
|
10
setup.py
10
setup.py
|
@ -10,7 +10,7 @@ setup(
|
|||
description="A Python wrapper for llama.cpp",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
version="0.1.17",
|
||||
version="0.1.22",
|
||||
author="Andrei Betlen",
|
||||
author_email="abetlen@gmail.com",
|
||||
license="MIT",
|
||||
|
@ -19,4 +19,12 @@ setup(
|
|||
"typing-extensions>=4.5.0",
|
||||
],
|
||||
python_requires=">=3.7",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
],
|
||||
)
|
||||
|
|
96
tests/test_llama.py
Normal file
96
tests/test_llama.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import llama_cpp
|
||||
|
||||
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
|
||||
|
||||
|
||||
def test_llama():
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
|
||||
assert llama
|
||||
assert llama.ctx is not None
|
||||
|
||||
text = b"Hello World"
|
||||
|
||||
assert llama.detokenize(llama.tokenize(text)) == text
|
||||
|
||||
|
||||
def test_llama_patch(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
|
||||
## Set up mock function
|
||||
def mock_eval(*args, **kwargs):
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
|
||||
output_text = " jumps over the lazy dog."
|
||||
output_tokens = llama.tokenize(output_text.encode("utf-8"))
|
||||
token_eos = llama.token_eos()
|
||||
n = 0
|
||||
|
||||
def mock_sample(*args, **kwargs):
|
||||
nonlocal n
|
||||
if n < len(output_tokens):
|
||||
n += 1
|
||||
return output_tokens[n - 1]
|
||||
else:
|
||||
return token_eos
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
|
||||
|
||||
text = "The quick brown fox"
|
||||
|
||||
## Test basic completion until eos
|
||||
n = 0 # reset
|
||||
completion = llama.create_completion(text, max_tokens=20)
|
||||
assert completion["choices"][0]["text"] == output_text
|
||||
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
## Test streaming completion until eos
|
||||
n = 0 # reset
|
||||
chunks = llama.create_completion(text, max_tokens=20, stream=True)
|
||||
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
|
||||
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
## Test basic completion until stop sequence
|
||||
n = 0 # reset
|
||||
completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
|
||||
assert completion["choices"][0]["text"] == " jumps over the "
|
||||
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
## Test streaming completion until stop sequence
|
||||
n = 0 # reset
|
||||
chunks = llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
|
||||
assert (
|
||||
"".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
|
||||
)
|
||||
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
## Test basic completion until length
|
||||
n = 0 # reset
|
||||
completion = llama.create_completion(text, max_tokens=2)
|
||||
assert completion["choices"][0]["text"] == " j"
|
||||
assert completion["choices"][0]["finish_reason"] == "length"
|
||||
|
||||
## Test streaming completion until length
|
||||
n = 0 # reset
|
||||
chunks = llama.create_completion(text, max_tokens=2, stream=True)
|
||||
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
|
||||
assert completion["choices"][0]["finish_reason"] == "length"
|
||||
|
||||
|
||||
def test_llama_pickle():
|
||||
import pickle
|
||||
import tempfile
|
||||
fp = tempfile.TemporaryFile()
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
pickle.dump(llama, fp)
|
||||
fp.seek(0)
|
||||
llama = pickle.load(fp)
|
||||
|
||||
assert llama
|
||||
assert llama.ctx is not None
|
||||
|
||||
text = b"Hello World"
|
||||
|
||||
assert llama.detokenize(llama.tokenize(text)) == text
|
Loading…
Add table
Reference in a new issue