Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into local-lib

Mug 2023-04-10 17:00:42 +02:00
commit 4132293d2d
23 changed files with 1307 additions and 42 deletions


@ -0,0 +1,71 @@
name: Build Release
on: workflow_dispatch
permissions:
contents: write
jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
# Used to host cibuildwheel
- uses: actions/setup-python@v3
- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.12.1
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
- uses: actions/upload-artifact@v3
with:
path: ./wheelhouse/*.whl
build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
- uses: actions/setup-python@v3
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
- name: Build source distribution
run: |
python setup.py sdist
- uses: actions/upload-artifact@v3
with:
path: ./dist/*.tar.gz
release:
name: Release
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
steps:
- uses: actions/download-artifact@v3
with:
name: artifact
path: dist
- uses: softprops/action-gh-release@v1
with:
files: dist/*
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/publish-to-test.yaml

@ -0,0 +1,30 @@
# Based on: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
name: Publish to TestPyPI
on: workflow_dispatch
jobs:
build-n-publish:
name: Build and publish
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
- name: Build source distribution
run: |
python setup.py sdist
- name: Publish to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository-url: https://test.pypi.org/legacy/

.github/workflows/publish.yaml

@ -0,0 +1,31 @@
name: Publish to PyPI
# Based on: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
on: workflow_dispatch
jobs:
build-n-publish:
name: Build and publish
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
- name: Build source distribution
run: |
python setup.py sdist
- name: Publish distribution to PyPI
# TODO: move to tag based releases
# if: startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}


@ -1,12 +1,15 @@
name: Tests
on:
pull_request:
branches:
- main
push:
branches:
- main
jobs:
build:
build-linux:
runs-on: ubuntu-latest
strategy:
@ -23,8 +26,54 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build
python3 setup.py develop
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
pip install . -v
- name: Test with pytest
run: |
pytest
build-windows:
runs-on: windows-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
pip install . -v
- name: Test with pytest
run: |
pytest
build-macos:
runs-on: macos-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
with:
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip pytest cmake scikit-build setuptools
pip install . -v
- name: Test with pytest
run: |
pytest

.gitignore

@ -163,4 +163,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/


@ -2,8 +2,26 @@ cmake_minimum_required(VERSION 3.4...3.22)
project(llama_cpp)
if (UNIX)
add_custom_command(
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/libllama.so
COMMAND make libllama.so
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp
)
add_custom_target(
run ALL
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/libllama.so
)
install(
FILES ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/libllama.so
DESTINATION llama_cpp
)
else()
set(BUILD_SHARED_LIBS "On")
add_subdirectory(vendor/llama.cpp)
install(TARGETS llama LIBRARY DESTINATION llama_cpp)
install(
TARGETS llama
LIBRARY DESTINATION llama_cpp
RUNTIME DESTINATION llama_cpp
)
endif(UNIX)


@ -15,7 +15,7 @@ This package provides:
- OpenAI-like API
- LangChain compatibility
# Installation
## Installation
Install from PyPI:
@ -23,18 +23,18 @@ Install from PyPI:
pip install llama-cpp-python
```
# Usage
## High-level API
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="models/7B/...")
>>> llm = Llama(model_path="./models/7B/ggml-model.bin")
>>> output = llm("Q: Name the planets in the solar system? A: ", max_tokens=32, stop=["Q:", "\n"], echo=True)
>>> print(output)
{
"id": "cmpl-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
"object": "text_completion",
"created": 1679561337,
"model": "models/7B/...",
"model": "./models/7B/ggml-model.bin",
"choices": [
{
"text": "Q: Name the planets in the solar system? A: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune and Pluto.",
@ -51,6 +51,27 @@ pip install llama-cpp-python
}
```
## Web Server
`llama-cpp-python` offers a web server which aims to act as a drop-in replacement for the OpenAI API.
This allows you to use llama.cpp compatible models with any OpenAI compatible client (language libraries, services, etc).
To install the server package and get started:
```bash
pip install llama-cpp-python[server]
export MODEL=./models/7B/ggml-model.bin
python3 -m llama_cpp.server
```
Navigate to [http://localhost:8000/docs](http://localhost:8000/docs) to see the OpenAPI documentation.
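Because the server mirrors the OpenAI REST API, an existing OpenAI-compatible client can simply be pointed at it. Below is a minimal sketch using the `openai` Python package; the base URL assumes the server above is running locally on port 8000, and the API key is an arbitrary placeholder that the local server does not validate:
```python
import openai

openai.api_key = "sk-xxxxxxxx"  # placeholder; not checked by the local server
openai.api_base = "http://localhost:8000/v1"

# Completion request served by the local llama.cpp model
completion = openai.Completion.create(
    model="text-davinci-003",  # currently ignored by the server
    prompt="The quick brown fox jumps",
    max_tokens=5,
)
print(completion["choices"][0]["text"])
```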
## Low-level API
The low-level API is a direct `ctypes` binding to the C API provided by `llama.cpp`.
The entire API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and should mirror [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
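As a rough illustration of the low-level API, the sketch below loads a model and tokenizes a prompt using only the `ctypes` wrappers exposed by `llama_cpp.llama_cpp`; the model path and prompt are example values:
```python
import llama_cpp

# Default context parameters for the model
params = llama_cpp.llama_context_default_params()

# Load the model (example path; adjust to your ggml model file)
ctx = llama_cpp.llama_init_from_file(b"./models/7B/ggml-model.bin", params)

# Tokenize a prompt; the token buffer is sized to the context length
max_tokens = params.n_ctx
tokens = (llama_cpp.llama_token * int(max_tokens))()
n_tokens = llama_cpp.llama_tokenize(
    ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, True
)

# Evaluate the prompt tokens (n_past=0, 4 threads), then free the context
llama_cpp.llama_eval(ctx, tokens, n_tokens, 0, 4)
llama_cpp.llama_free(ctx)
```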
# Documentation
Documentation is available at [https://abetlen.github.io/llama-cpp-python](https://abetlen.github.io/llama-cpp-python).


@ -1,5 +1,9 @@
# 🦙 Python Bindings for `llama.cpp`
# Getting Started
## 🦙 Python Bindings for `llama.cpp`
[![Documentation](https://img.shields.io/badge/docs-passing-green.svg)](https://abetlen.github.io/llama-cpp-python)
[![Tests](https://github.com/abetlen/llama-cpp-python/actions/workflows/test.yaml/badge.svg?branch=main)](https://github.com/abetlen/llama-cpp-python/actions/workflows/test.yaml)
[![PyPI](https://img.shields.io/pypi/v/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
[![PyPI - License](https://img.shields.io/pypi/l/llama-cpp-python)](https://pypi.org/project/llama-cpp-python/)
@ -21,18 +25,18 @@ Install from PyPI:
pip install llama-cpp-python
```
## Usage
## High-level API
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="models/7B/...")
>>> llm = Llama(model_path="./models/7B/ggml-model.bin")
>>> output = llm("Q: Name the planets in the solar system? A: ", max_tokens=32, stop=["Q:", "\n"], echo=True)
>>> print(output)
{
"id": "cmpl-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
"object": "text_completion",
"created": 1679561337,
"model": "models/7B/...",
"model": "./models/7B/ggml-model.bin",
"choices": [
{
"text": "Q: Name the planets in the solar system? A: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune and Pluto.",
@ -49,8 +53,33 @@ pip install llama-cpp-python
}
```
## Web Server
`llama-cpp-python` offers a web server which aims to act as a drop-in replacement for the OpenAI API.
This allows you to use llama.cpp compatible models with any OpenAI compatible client (language libraries, services, etc).
To install the server package and get started:
```bash
pip install llama-cpp-python[server]
export MODEL=./models/7B/ggml-model.bin
python3 -m llama_cpp.server
```
Navigate to [http://localhost:8000/docs](http://localhost:8000/docs) to see the OpenAPI documentation.
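Since the server speaks the OpenAI protocol, OpenAI-compatible libraries such as LangChain can also target it. A small sketch, assuming the server above is running locally and using the standard `OPENAI_API_*` environment variables (the key itself is an arbitrary placeholder):
```python
import os

# Point OpenAI-compatible clients at the local server
os.environ["OPENAI_API_KEY"] = "sk-xxxxxxxx"  # placeholder; not validated locally
os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"

from langchain.llms import OpenAI

llm = OpenAI()
print(llm("The quick brown fox jumps", stop=[".", "\n"]))
```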
## Low-level API
The low-level API is a direct `ctypes` binding to the C API provided by `llama.cpp`.
The entire API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and should mirror [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
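The binding also exposes capability checks for memory mapping and locking alongside the `use_mmap` and `use_mlock` options. A small sketch of querying them before setting the corresponding context parameters (a rough illustration, not required for normal use):
```python
import llama_cpp

# Query platform capabilities before enabling the corresponding flags
params = llama_cpp.llama_context_default_params()
params.use_mmap = bool(llama_cpp.llama_mmap_supported())
params.use_mlock = bool(llama_cpp.llama_mlock_supported())
print(params.use_mmap, params.use_mlock)
```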
## Development
This package is under active development and I welcome any contributions.
To get started, clone the repository and install the package in development mode:
```bash
git clone git@github.com:abetlen/llama-cpp-python.git
git submodule update --init --recursive


@ -4,7 +4,7 @@ To run this example:
```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
export MODEL=../models/7B/ggml-model.bin
uvicorn fastapi_server_chat:app --reload
```
@ -13,7 +13,8 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import json
from typing import List, Optional, Literal, Union, Iterator
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict
import llama_cpp
@ -26,10 +27,10 @@ from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str
n_ctx: int = 2048
n_batch: int = 2048
n_threads: int = os.cpu_count() or 1
n_batch: int = 8
n_threads: int = int(os.cpu_count() / 2) or 1
f16_kv: bool = True
use_mlock: bool = True
use_mlock: bool = False  # Setting this to True causes a silent failure on platforms that don't support mlock (e.g. Windows); took forever to figure out...
embedding: bool = True
last_n_tokens_size: int = 64
@ -64,13 +65,24 @@ class CreateCompletionRequest(BaseModel):
max_tokens: int = 16
temperature: float = 0.8
top_p: float = 0.95
logprobs: Optional[int] = Field(None)
echo: bool = False
stop: List[str] = []
repeat_penalty: float = 1.1
top_k: int = 40
stream: bool = False
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
logprobs: Optional[int] = Field(None)
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
best_of: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = 40
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
@ -91,7 +103,20 @@ def create_completion(request: CreateCompletionRequest):
if request.stream:
chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
return llama(**request.dict())
return llama(
**request.dict(
exclude={
"model",
"n",
"logprobs",
"frequency_penalty",
"presence_penalty",
"best_of",
"logit_bias",
"user",
}
)
)
class CreateEmbeddingRequest(BaseModel):
@ -132,6 +157,16 @@ class CreateChatCompletionRequest(BaseModel):
stream: bool = False
stop: List[str] = []
max_tokens: int = 128
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
repeat_penalty: float = 1.1
class Config:
@ -160,7 +195,16 @@ async def create_chat_completion(
request: CreateChatCompletionRequest,
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(exclude={"model"}),
**request.dict(
exclude={
"model",
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
}
),
)
if request.stream:
@ -179,3 +223,40 @@ async def create_chat_completion(
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",
"data": [
{
"id": llama.model_path,
"object": "model",
"owned_by": "me",
"permissions": [],
}
],
}
if __name__ == "__main__":
import os
import uvicorn
uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))


@ -3,7 +3,7 @@ import argparse
from llama_cpp import Llama
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=".//models/...")
parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-model.bin")
args = parser.parse_args()
llm = Llama(model_path=args.model, embedding=True)


@ -4,7 +4,7 @@ import argparse
from llama_cpp import Llama
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="./models/...")
parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
args = parser.parse_args()
llm = Llama(model_path=args.model)


@ -4,7 +4,7 @@ import argparse
from llama_cpp import Llama
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="./models/...")
parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
args = parser.parse_args()
llm = Llama(model_path=args.model)


@ -29,7 +29,7 @@ class LlamaLLM(LLM):
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="./models/...")
parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
args = parser.parse_args()
# Load the model


@ -0,0 +1,148 @@
import os
import argparse
from dataclasses import dataclass, field
from typing import List, Optional
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
@dataclass
class GptParams:
seed: int = -1
n_threads: int = min(4, os.cpu_count() or 1)
n_predict: int = 128
repeat_last_n: int = 64
n_parts: int = -1
n_ctx: int = 512
n_batch: int = 8
n_keep: int = 0
top_k: int = 40
top_p: float = 0.95
temp: float = 0.80
repeat_penalty: float = 1.10
model: str = "./models/llama-7B/ggml-model.bin"
prompt: str = ""
input_prefix: str = " "
antiprompt: List[str] = field(default_factory=list)
memory_f16: bool = True
random_prompt: bool = False
use_color: bool = False
interactive: bool = False
embedding: bool = False
interactive_start: bool = False
instruct: bool = False
ignore_eos: bool = False
perplexity: bool = False
use_mlock: bool = False
mem_test: bool = False
verbose_prompt: bool = False
file: Optional[str] = None
# If chat ended prematurely, append this to the conversation to fix it.
# Set to "\nUser:" etc.
# This is an alternative to input_prefix, which always adds it and so can duplicate "User:"
fix_prefix: str = " "
output_postfix: str = ""
input_echo: bool = True
# Default instructions for Alpaca
# switch to "Human" and "Assistant" for Vicuna.
# TODO: TBD how they are gonna handle this upstream
instruct_inp_prefix: str="\n\n### Instruction:\n\n"
instruct_inp_suffix: str="\n\n### Response:\n\n"
def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
if params is None:
params = GptParams()
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)",dest="seed")
parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation",dest="n_threads")
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context",dest="n_ctx")
parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value",dest="memory_f16")
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
parser.add_argument("--n_predict", type=int, default=128, help="number of model parts",dest="n_predict")
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt",dest="n_keep")
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
parser.add_argument(
"-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
)
parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
parser.add_argument(
"--interactive-start",
action="store_true",
help="run in interactive mode",
dest="interactive"
)
parser.add_argument(
"--interactive-first",
action="store_true",
help="run in interactive mode and wait for input right away",
dest="interactive_start"
)
parser.add_argument(
"-ins",
"--instruct",
action="store_true",
help="run in instruction mode (use with Alpaca or Vicuna models)",
dest="instruct"
)
parser.add_argument(
"--color",
action="store_true",
help="colorise output to distinguish prompt and user input from generations",
dest="use_color"
)
parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
parser.add_argument(
"-r",
"--reverse-prompt",
type=str,
action='append',
help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
dest="antiprompt"
)
parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt")
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix")
parser.add_argument("--out-postfix", type=str, default="", help="append to input", dest="output_postfix")
parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo")
args = parser.parse_args(argv)
return args
def gpt_random_prompt(rng):
return [
"So",
"Once upon a time",
"When",
"The",
"After",
"If",
"import",
"He",
"She",
"They",
][rng % 10]
if __name__ == "__main__":
print(GptParams(**vars(gpt_params_parse())))


@ -0,0 +1,393 @@
"""
This is an example implementation of main.cpp from llama.cpp
Quirks:
* It's not an exact copy, since this port is designed around programmatic I/O
* Input is always echoed if echo is on, so it should be turned off when using "input()"
* The first antiprompt should be the user prompt, like "\nUser:",
because it's added when n_predict is reached (i.e. generation ended prematurely)
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
* Instruction mode adds its own antiprompt.
You should also still be feeding the model with a "primer" prompt that
shows it the expected format.
"""
import sys
from time import time
from os import cpu_count
import llama_cpp
from common import GptParams, gpt_params_parse, gpt_random_prompt
ANSI_COLOR_RESET = "\x1b[0m"
ANSI_COLOR_YELLOW = "\x1b[33m"
ANSI_BOLD = "\x1b[1m"
ANSI_COLOR_GREEN = "\x1b[32m"
CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
# A LLaMA interactive session
class LLaMAInteract:
def __init__(self, params: GptParams) -> None:
# input args
self.params = params
if (self.params.perplexity):
raise NotImplementedError("""************
please use the 'perplexity' tool for perplexity calculations
************""")
if (self.params.embedding):
raise NotImplementedError("""************
please use the 'embedding' tool for embedding calculations
************""")
if (self.params.n_ctx > 2048):
print(f"""warning: model does not support \
context sizes greater than 2048 tokens ({self.params.n_ctx} \
specified) expect poor results""", file=sys.stderr)
if (self.params.seed <= 0):
self.params.seed = int(time())
print(f"seed = {self.params.seed}", file=sys.stderr)
if (self.params.random_prompt):
self.params.prompt = gpt_random_prompt(self.params.seed)
# runtime args
self.input_consumed = 0
self.n_past = 0
self.first_antiprompt = []
self.remaining_tokens = self.params.n_predict
self.output_echo = self.params.input_echo
# model load
self.lparams = llama_cpp.llama_context_default_params()
self.lparams.n_ctx = self.params.n_ctx
self.lparams.n_parts = self.params.n_parts
self.lparams.seed = self.params.seed
self.lparams.memory_f16 = self.params.memory_f16
self.lparams.use_mlock = self.params.use_mlock
self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
if (not self.ctx):
raise RuntimeError(f"error: failed to load model '{self.params.model}'")
print(file=sys.stderr)
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
# determine the required inference memory per token:
if (self.params.mem_test):
tmp = [0, 1, 2, 3]
llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.params.n_threads)
llama_cpp.llama_print_timings(self.ctx)
self.exit()
return
# create internal context
self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
# Add a space in front of the first character to match OG llama tokenizer behavior
self.params.prompt = " " + self.params.prompt
# Load prompt file
if (self.params.file):
with open(self.params.file) as f:
self.params.prompt = f.read()
# tokenize the prompt
self.embd = []
self.embd_inp = self._tokenize(self.params.prompt)
if (len(self.embd_inp) > self.params.n_ctx - 4):
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
# number of tokens to keep when resetting context
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
self.params.n_keep = len(self.embd_inp)
self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)
# in instruct mode, we inject a prefix and a suffix to each input by the user
if (self.params.instruct):
self.params.interactive_start = True
self.first_antiprompt.append(self._tokenize(self.params.instruct_inp_prefix.strip(), False))
# enable interactive mode if reverse prompt or interactive start is specified
if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
self.params.interactive = True
# determine newline token
self.llama_token_newline = self._tokenize("\n", False)
if (self.params.verbose_prompt):
print(f"""
prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
for i in range(len(self.embd_inp)):
print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
if (self.params.n_keep > 0):
print("static prompt based on n_keep: '")
for i in range(self.params.n_keep):
print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
print("'", file=sys.stderr)
print(file=sys.stderr)
if (self.params.interactive):
print("interactive mode on.", file=sys.stderr)
if (len(self.params.antiprompt) > 0):
for antiprompt in self.params.antiprompt:
print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)
if len(self.params.input_prefix) > 0:
print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)
print(f"""sampling: temp = {self.params.temp},\
top_k = {self.params.top_k},\
top_p = {self.params.top_p},\
repeat_last_n = {self.params.repeat_last_n},\
repeat_penalty = {self.params.repeat_penalty}
generate: n_ctx = {self.n_ctx}, \
n_batch = {self.params.n_batch}, \
n_predict = {self.params.n_predict}, \
n_keep = {self.params.n_keep}
""", file=sys.stderr)
# determine antiprompt tokens
for i in self.params.antiprompt:
self.first_antiprompt.append(self._tokenize(i, False))
self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesn't support slices
if (params.interactive):
print("""== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to LLaMa.
- If you want to submit another line, end your input in '\\'.
""", file=sys.stderr)
self.set_color(CONSOLE_COLOR_PROMPT)
# tokenize a prompt
def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
return _arr[:_n]
def use_antiprompt(self):
return len(self.first_antiprompt) > 0
def set_color(self, c):
if (self.params.use_color):
print(c, end="")
# generate tokens
def generate(self):
while self.remaining_tokens > 0 or self.params.interactive:
# predict
if len(self.embd) > 0:
# infinite text generation via context swapping
# if we run out of context:
# - take the n_keep first tokens from the original prompt (via n_past)
# - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
if (self.n_past + len(self.embd) > self.n_ctx):
n_left = self.n_past - self.params.n_keep
self.n_past = self.params.n_keep
# insert n_left/2 tokens at the start of embd from last_n_tokens
_insert = self.last_n_tokens[
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
]
self.embd = _insert + self.embd
if (llama_cpp.llama_eval(
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
) != 0):
raise Exception("Failed to llama_eval!")
self.n_past += len(self.embd)
self.embd = []
if len(self.embd_inp) <= self.input_consumed:
# out of user input, sample next token
#TODO: self.params.ignore_eos
_arr = self.last_n_tokens[-min(self.params.repeat_last_n, self.n_past):]
id = llama_cpp.llama_sample_top_p_top_k(
self.ctx,
(llama_cpp.llama_token * len(_arr))(*_arr),
len(_arr),
self.params.top_k,
self.params.top_p,
self.params.temp,
self.params.repeat_penalty,
)
self.last_n_tokens.pop(0)
self.last_n_tokens.append(id)
# replace end of text token with newline token when in interactive mode
if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
id = self.llama_token_newline[0]
if (self.use_antiprompt()):
# tokenize and inject first reverse prompt
self.embd_inp += self.first_antiprompt[0]
# add it to the context
self.embd.append(id)
# echo this to console
self.output_echo = True
# decrement remaining sampling budget
self.remaining_tokens -= 1
else:
# output to console if input echo is on
self.output_echo = self.params.input_echo
# some user input remains from prompt or interaction, forward it to processing
while len(self.embd_inp) > self.input_consumed:
self.embd.append(self.embd_inp[self.input_consumed])
self.last_n_tokens.pop(0)
self.last_n_tokens.append(self.embd_inp[self.input_consumed])
self.input_consumed += 1
if len(self.embd) >= self.params.n_batch:
break
# display tokens
if self.output_echo:
for id in self.embd:
yield id
# reset color to default if there is no pending user input
if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
self.set_color(CONSOLE_COLOR_DEFAULT)
if (self.params.interactive and len(self.embd_inp) <= self.input_consumed):
# if antiprompt is present, stop
if (self.use_antiprompt()):
if True in [
i == self.last_n_tokens[-len(i):]
for i in self.first_antiprompt
]:
break
# if we are using instruction mode, and we have processed the initial prompt
if (self.n_past > 0 and self.params.interactive_start):
break
# end of text token
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
if (not self.params.instruct):
for i in " [end of text]\n":
yield i
break
# respect n_predict even if antiprompt is present
if (self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1):
# If we aren't in instruction mode, fix the current generation by appending the antiprompt.
# This prevents appending the AI's text etc. if the chat ends prematurely.
if not self.params.instruct:
self.embd_inp += self.first_antiprompt[0]
self.remaining_tokens = self.params.n_predict
break
self.params.interactive_start = False
def __enter__(self):
return self
def __exit__(self, type, value, tb):
self.exit()
def exit(self):
llama_cpp.llama_free(self.ctx)
self.set_color(CONSOLE_COLOR_DEFAULT)
# return past text
def past(self):
for id in self.last_n_tokens[-self.n_past:]:
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
# write input
def input(self, prompt: str):
if (self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix):
self.embd_inp += self.inp_prefix
self.embd_inp += self._tokenize(prompt)
if (self.params.instruct):
self.embd_inp += self.inp_suffix
# write output
def output(self):
self.remaining_tokens = self.params.n_predict
for id in self.generate():
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
# read user input
def read_input(self):
out = ""
while (t := input()).endswith("\\"):
out += t[:-1] + "\n"
return out + t + "\n"
# interactive mode
def interact(self):
for i in self.output():
print(i,end="",flush=True)
self.params.input_echo = False
while self.params.interactive:
self.set_color(CONSOLE_COLOR_USER_INPUT)
if (self.params.instruct):
print('\n> ', end="")
self.input(self.read_input())
else:
print(self.params.input_prefix, end="")
self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.output_postfix}")
print(self.params.output_postfix,end="")
self.set_color(CONSOLE_COLOR_DEFAULT)
try:
for i in self.output():
print(i,end="",flush=True)
except KeyboardInterrupt:
self.set_color(CONSOLE_COLOR_DEFAULT)
if not self.params.instruct:
print(self.params.fix_prefix,end="")
self.input(self.params.fix_prefix)
if __name__ == "__main__":
from datetime import datetime
USER_NAME="User"
AI_NAME="ChatLLaMa"
time_now = datetime.now()
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}'s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
The transcript only includes text, it does not include markup like HTML and Markdown.
{USER_NAME}: Hello, {AI_NAME}!
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
{USER_NAME}: What time is it?
{AI_NAME}: It is {time_now.strftime("%H:%M")}.
{USER_NAME}: What year is it?
{AI_NAME}: We are in {time_now.strftime("%Y")}.
{USER_NAME}: What is a cat?
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{USER_NAME}: Name a color.
{AI_NAME}: Blue
{USER_NAME}:"""
args = gpt_params_parse()
params = GptParams(**vars(args))
with LLaMAInteract(params) as m:
m.interact()


@ -9,7 +9,7 @@ N_THREADS = multiprocessing.cpu_count()
prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"
lparams = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"models/ggml-alpaca-7b-q4.bin", lparams)
ctx = llama_cpp.llama_init_from_file(b"../models/7B/ggml-model.bin", lparams)
# determine the required inference memory per token:
tmp = [0, 1, 2, 3]


@ -0,0 +1,104 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<OpenAIObject text_completion id=cmpl-ad3ba53d-407c-466b-bd5f-97cb8987af83 at 0x7f6adc12d900> JSON: {\n",
" \"choices\": [\n",
" {\n",
" \"finish_reason\": \"length\",\n",
" \"index\": 0,\n",
" \"logprobs\": null,\n",
" \"text\": \" over the lazy dog.\"\n",
" }\n",
" ],\n",
" \"created\": 1680960690,\n",
" \"id\": \"cmpl-ad3ba53d-407c-466b-bd5f-97cb8987af83\",\n",
" \"model\": \"models/ggml-alpaca.bin\",\n",
" \"object\": \"text_completion\",\n",
" \"usage\": {\n",
" \"completion_tokens\": 5,\n",
" \"prompt_tokens\": 8,\n",
" \"total_tokens\": 13\n",
" }\n",
"}"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"\n",
"openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n",
"openai.api_base = \"http://100.64.159.73:8000/v1\"\n",
"\n",
"openai.Completion.create(\n",
" model=\"text-davinci-003\", # currently can be anything\n",
" prompt=\"The quick brown fox jumps\",\n",
" max_tokens=5,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' over the lazy dog'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n",
"os.environ[\"OPENAI_API_BASE\"] = \"http://100.64.159.73:8000/v1\"\n",
"\n",
"from langchain.llms import OpenAI\n",
"\n",
"llms = OpenAI()\n",
"llms(\n",
" prompt=\"The quick brown fox jumps\",\n",
" stop=[\".\", \"\\n\"],\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}


@ -23,6 +23,7 @@ class Llama:
f16_kv: bool = False,
logits_all: bool = False,
vocab_only: bool = False,
use_mmap: bool = True,
use_mlock: bool = False,
embedding: bool = False,
n_threads: Optional[int] = None,
@ -40,6 +41,7 @@ class Llama:
f16_kv: Use half-precision for key/value cache.
logits_all: Return logits for all tokens, not just the last token.
vocab_only: Only load the vocabulary no weights.
use_mmap: Use mmap if possible.
use_mlock: Force the system to keep the model in RAM.
embedding: Embedding mode only.
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
@ -63,6 +65,7 @@ class Llama:
self.params.f16_kv = f16_kv
self.params.logits_all = logits_all
self.params.vocab_only = vocab_only
self.params.use_mmap = use_mmap
self.params.use_mlock = use_mlock
self.params.embedding = embedding
@ -74,7 +77,7 @@ class Llama:
self.tokens_consumed = 0
self.n_batch = min(n_ctx, n_batch)
self.n_threads = n_threads or multiprocessing.cpu_count()
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
@ -661,6 +664,7 @@ class Llama:
f16_kv=self.params.f16_kv,
logits_all=self.params.logits_all,
vocab_only=self.params.vocab_only,
use_mmap=self.params.use_mmap,
use_mlock=self.params.use_mlock,
embedding=self.params.embedding,
last_n_tokens_size=self.last_n_tokens_size,
@ -679,6 +683,7 @@ class Llama:
f16_kv=state["f16_kv"],
logits_all=state["logits_all"],
vocab_only=state["vocab_only"],
use_mmap=state["use_mmap"],
use_mlock=state["use_mlock"],
embedding=state["embedding"],
n_threads=state["n_threads"],


@ -10,7 +10,7 @@ def _load_shared_library(lib_base_name):
if sys.platform.startswith("linux"):
lib_ext = ".so"
elif sys.platform == "darwin":
lib_ext = ".dylib"
lib_ext = ".so"
elif sys.platform == "win32":
lib_ext = ".dll"
else:
@ -80,6 +80,7 @@ class llama_context_params(Structure):
c_bool,
), # the llama_eval() call computes all logits, not just the last one
("vocab_only", c_bool), # only load the vocabulary, no weights
("use_mmap", c_bool), # use mmap if possible
("use_mlock", c_bool), # force system to keep model in RAM
("embedding", c_bool), # embedding mode only
# called with a progress value between 0 and 1, pass NULL to disable
@ -102,6 +103,17 @@ def llama_context_default_params() -> llama_context_params:
_lib.llama_context_default_params.argtypes = []
_lib.llama_context_default_params.restype = llama_context_params
def llama_mmap_supported() -> c_bool:
return _lib.llama_mmap_supported()
_lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool
def llama_mlock_supported() -> c_bool:
return _lib.llama_mlock_supported()
_lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
# Various functions for loading a ggml llama model.
# Allocate (almost) all memory needed for the model.
@ -221,7 +233,7 @@ _lib.llama_n_ctx.restype = c_int
def llama_n_embd(ctx: llama_context_p) -> c_int:
return _lib.llama_n_ctx(ctx)
return _lib.llama_n_embd(ctx)
_lib.llama_n_embd.argtypes = [llama_context_p]


@ -0,0 +1,269 @@
"""Example FastAPI server for llama.cpp.
To run this example:
```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
uvicorn fastapi_server_chat:app --reload
```
Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import json
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict
import llama_cpp
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str
n_ctx: int = 2048
n_batch: int = 8
n_threads: int = ((os.cpu_count() or 2) // 2) or 1
f16_kv: bool = True
use_mlock: bool = False  # Setting this to True causes a silent failure on platforms that don't support mlock (e.g. Windows); took forever to figure out...
embedding: bool = True
last_n_tokens_size: int = 64
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
settings = Settings()
llama = llama_cpp.Llama(
settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
embedding=settings.embedding,
n_threads=settings.n_threads,
n_batch=settings.n_batch,
n_ctx=settings.n_ctx,
last_n_tokens_size=settings.last_n_tokens_size,
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]]
suffix: Optional[str] = Field(None)
max_tokens: int = 16
temperature: float = 0.8
top_p: float = 0.95
echo: bool = False
stop: List[str] = []
stream: bool = False
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
logprobs: Optional[int] = Field(None)
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
best_of: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = 40
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
}
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(request: CreateCompletionRequest):
if isinstance(request.prompt, list):
request.prompt = "".join(request.prompt)
completion_or_chunks = llama(
**request.dict(
exclude={
"model",
"n",
"logprobs",
"frequency_penalty",
"presence_penalty",
"best_of",
"logit_bias",
"user",
}
)
)
if request.stream:
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
return completion
class CreateEmbeddingRequest(BaseModel):
model: Optional[str]
input: str
user: Optional[str]
class Config:
schema_extra = {
"example": {
"input": "The food was delicious and the waiter...",
}
}
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(request: CreateEmbeddingRequest):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
class ChatCompletionRequestMessage(BaseModel):
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
content: str
user: Optional[str] = None
class CreateChatCompletionRequest(BaseModel):
model: Optional[str]
messages: List[ChatCompletionRequestMessage]
temperature: float = 0.8
top_p: float = 0.95
stream: bool = False
stop: List[str] = []
max_tokens: int = 128
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
),
]
}
}
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
async def create_chat_completion(
request: CreateChatCompletionRequest,
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(
exclude={
"model",
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
}
),
)
if request.stream:
async def server_sent_events(
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
):
for chat_chunk in chat_chunks:
yield dict(data=json.dumps(chat_chunk))
yield dict(data="[DONE]")
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(
server_sent_events(chunks),
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",
"data": [
{
"id": llama.model_path,
"object": "model",
"owned_by": "me",
"permissions": [],
}
],
}
if __name__ == "__main__":
import os
import uvicorn
uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
)


@ -1,6 +1,6 @@
[tool.poetry]
name = "llama_cpp"
version = "0.1.22"
name = "llama_cpp_python"
version = "0.1.30"
description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT"


@ -10,14 +10,18 @@ setup(
description="A Python wrapper for llama.cpp",
long_description=long_description,
long_description_content_type="text/markdown",
version="0.1.22",
version="0.1.30",
author="Andrei Betlen",
author_email="abetlen@gmail.com",
license="MIT",
packages=["llama_cpp"],
package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
packages=["llama_cpp", "llama_cpp.server"],
install_requires=[
"typing-extensions>=4.5.0",
],
extras_require={
"server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
},
python_requires=">=3.7",
classifiers=[
"Programming Language :: Python :: 3",

vendor/llama.cpp

@ -1 +1 @@
Subproject commit 53dbba769537e894ead5c6913ab2fd3a4658b738
Subproject commit 180b693a47b6b825288ef9f2c39d24b6eea4eea6