Merge branch 'main' into better-server-params-and-fields

Andrei 2023-05-01 22:45:57 -04:00 committed by GitHub
commit 7ab08b8d10
5 changed files with 50 additions and 39 deletions


@@ -90,7 +90,7 @@ This package is under active development and I welcome any contributions.
To get started, clone the repository and install the package in development mode:
```bash
git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
# Will need to be re-run any time vendor/llama.cpp is updated
python3 setup.py develop
```


@@ -306,7 +306,7 @@ class Llama:
llama_cpp.llama_sample_typical(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=llama_cpp.c_float(1.0)
p=llama_cpp.c_float(1.0),
)
llama_cpp.llama_sample_top_p(
ctx=self.ctx,
@@ -637,10 +637,7 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [
Llama._logits_to_logprobs(row)
for row in self.eval_logits
]
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
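
For context on the one-liner above: `Llama._logits_to_logprobs` is applied to each row of `self.eval_logits` to turn raw logits into per-token log-probabilities. A minimal standalone sketch of that conversion, assuming it is a plain log-softmax (the function name and implementation here are illustrative, not the library's actual code):

```python
import math
from typing import List


def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Log-softmax: shift by the max logit for numerical stability, then
    # subtract the log of the summed exponentials so that exponentiating
    # the results yields a probability distribution.
    max_logit = max(logits)
    log_denominator = max_logit + math.log(
        sum(math.exp(x - max_logit) for x in logits)
    )
    return [x - log_denominator for x in logits]
```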


@@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
import os
import uvicorn
from llama_cpp.server.app import app, init_llama
from llama_cpp.server.app import create_app
if __name__ == "__main__":
init_llama()
app = create_app()
uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
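
Since `create_app` replaces the module-level `app` and `init_llama()`, the server can also be constructed programmatically. A hedged sketch, where the model path is a placeholder and the keyword arguments mirror the `Settings` fields shown in the next file:

```python
import uvicorn

from llama_cpp.server.app import Settings, create_app

# The model path is illustrative; any Settings field can be overridden here.
settings = Settings(model="./models/ggml-model-q4_0.bin", n_ctx=2048)
app = create_app(settings)

uvicorn.run(app, host="localhost", port=8000)
```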


@@ -2,18 +2,18 @@ import os
import json
from threading import Lock
from typing import List, Optional, Union, Iterator, Dict
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal, Annotated
import llama_cpp
from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str = os.environ.get("MODEL", "null")
model: str
n_ctx: int = 2048
n_batch: int = 512
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
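
Because `Settings` extends pydantic's `BaseSettings` and `model` no longer falls back to `os.environ.get("MODEL", "null")`, unset fields are read from the environment by pydantic itself (field names match environment variables case-insensitively). A small sketch of that behaviour, with an illustrative model path:

```python
import os

from llama_cpp.server.app import Settings

os.environ["MODEL"] = "./models/ggml-model-q4_0.bin"  # illustrative path

settings = Settings()  # `model` is populated from the MODEL environment variable
print(settings.model, settings.n_ctx, settings.n_threads)
```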
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
vocab_only: bool = False
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
router = APIRouter()
llama: llama_cpp.Llama = None
def init_llama(settings: Settings = None):
llama: Optional[llama_cpp.Llama] = None
def create_app(settings: Optional[Settings] = None):
if settings is None:
settings = Settings()
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(router)
global llama
llama = llama_cpp.Llama(
settings.model,
model_path=settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
use_mmap=settings.use_mmap,
@@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
if settings.cache:
cache = llama_cpp.LlamaCache()
llama.set_cache(cache)
return app
llama_lock = Lock()
def get_llama():
with llama_lock:
yield llama
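
`get_llama` is a generator dependency that holds `llama_lock` for the duration of a request, so concurrent requests take turns using the single shared model. A hypothetical route inside `app.py` showing how an endpoint would consume it via `Depends` (the `/health` path is made up for illustration and is not part of the server):

```python
from fastapi import Depends  # already imported at the top of app.py


@router.get("/health")  # hypothetical endpoint, for illustration only
def health(llama: llama_cpp.Llama = Depends(get_llama)):
    # While this handler runs, get_llama() is holding llama_lock, so no
    # other request can use the shared model at the same time.
    return {"model_loaded": llama is not None}
```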
@@ -117,8 +125,6 @@ repeat_penalty_field = Field(
"Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
default="",
@@ -162,7 +168,7 @@ class CreateCompletionRequest(BaseModel):
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
@router.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
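
Because the route is registered on `router`, which `create_app` includes into the app, the public URL is unchanged. A hedged client-side sketch using the `requests` package, assuming a locally running server and the OpenAI-style body defined by `CreateCompletionRequest`:

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Q: Name the planets in the solar system. A: ",
        "max_tokens": 64,
        "temperature": 0.8,
    },
)
print(response.json()["choices"][0]["text"])
```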
@@ -204,7 +210,7 @@ class CreateEmbeddingRequest(BaseModel):
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
@router.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
@@ -257,7 +263,7 @@ class CreateChatCompletionRequest(BaseModel):
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
@router.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
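
The chat route follows the same pattern. A short client sketch, assuming the OpenAI-style `messages` schema from `CreateChatCompletionRequest`:

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "What is the capital of France?"}]},
)
print(response.json()["choices"][0]["message"]["content"])
```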
@@ -306,7 +312,7 @@ class ModelList(TypedDict):
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
@router.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",


@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
## Set up mock function
def mock_eval(*args, **kwargs):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
return (llama_cpp.c_float * n_vocab)(
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
def test_llama_pickle():
import pickle
import tempfile
fp = tempfile.TemporaryFile()
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
assert llama.detokenize(llama.tokenize(text)) == text
def test_utf8(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
return (llama_cpp.c_float * n_vocab)(
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,13 @@ def test_utf8(monkeypatch):
def test_llama_server():
from fastapi.testclient import TestClient
from llama_cpp.server.app import app, init_llama, Settings
s = Settings()
s.model = MODEL
s.vocab_only = True
init_llama(s)
from llama_cpp.server.app import create_app, Settings
settings = Settings(
model=MODEL,
vocab_only=True,
)
app = create_app(settings)
client = TestClient(app)
response = client.get("/v1/models")
assert response.json() == {