Merge branch 'main' into better-server-params-and-fields
This commit is contained in:
commit
7ab08b8d10
5 changed files with 50 additions and 39 deletions
|
@ -90,7 +90,7 @@ This package is under active development and I welcome any contributions.
|
|||
To get started, clone the repository and install the package in development mode:
|
||||
|
||||
```bash
|
||||
git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
|
||||
git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
|
||||
# Will need to be re-run any time vendor/llama.cpp is updated
|
||||
python3 setup.py develop
|
||||
```
|
||||
|
|
|
@ -306,7 +306,7 @@ class Llama:
|
|||
llama_cpp.llama_sample_typical(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
p=llama_cpp.c_float(1.0)
|
||||
p=llama_cpp.c_float(1.0),
|
||||
)
|
||||
llama_cpp.llama_sample_top_p(
|
||||
ctx=self.ctx,
|
||||
|
@ -637,10 +637,7 @@ class Llama:
|
|||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
Llama._logits_to_logprobs(row)
|
||||
for row in self.eval_logits
|
||||
]
|
||||
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
):
|
||||
|
|
|
@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
|||
import os
|
||||
import uvicorn
|
||||
|
||||
from llama_cpp.server.app import app, init_llama
|
||||
from llama_cpp.server.app import create_app
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_llama()
|
||||
app = create_app()
|
||||
|
||||
uvicorn.run(
|
||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||
|
|
|
@ -2,18 +2,18 @@ import os
|
|||
import json
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Union, Iterator, Dict
|
||||
from typing_extensions import TypedDict, Literal
|
||||
from typing_extensions import TypedDict, Literal, Annotated
|
||||
|
||||
import llama_cpp
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi import Depends, FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model: str = os.environ.get("MODEL", "null")
|
||||
model: str
|
||||
n_ctx: int = 2048
|
||||
n_batch: int = 512
|
||||
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
||||
|
@ -27,25 +27,29 @@ class Settings(BaseSettings):
|
|||
vocab_only: bool = False
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="🦙 llama.cpp Python API",
|
||||
version="0.0.1",
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
router = APIRouter()
|
||||
|
||||
llama: llama_cpp.Llama = None
|
||||
def init_llama(settings: Settings = None):
|
||||
llama: Optional[llama_cpp.Llama] = None
|
||||
|
||||
|
||||
def create_app(settings: Optional[Settings] = None):
|
||||
if settings is None:
|
||||
settings = Settings()
|
||||
app = FastAPI(
|
||||
title="🦙 llama.cpp Python API",
|
||||
version="0.0.1",
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.include_router(router)
|
||||
global llama
|
||||
llama = llama_cpp.Llama(
|
||||
settings.model,
|
||||
model_path=settings.model,
|
||||
f16_kv=settings.f16_kv,
|
||||
use_mlock=settings.use_mlock,
|
||||
use_mmap=settings.use_mmap,
|
||||
|
@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
|
|||
if settings.cache:
|
||||
cache = llama_cpp.LlamaCache()
|
||||
llama.set_cache(cache)
|
||||
return app
|
||||
|
||||
|
||||
llama_lock = Lock()
|
||||
|
||||
|
||||
def get_llama():
|
||||
with llama_lock:
|
||||
yield llama
|
||||
|
@ -117,8 +125,6 @@ repeat_penalty_field = Field(
|
|||
"Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: Union[str, List[str]] = Field(
|
||||
default="",
|
||||
|
@ -162,7 +168,7 @@ class CreateCompletionRequest(BaseModel):
|
|||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||
|
||||
|
||||
@app.post(
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
response_model=CreateCompletionResponse,
|
||||
)
|
||||
|
@ -204,7 +210,7 @@ class CreateEmbeddingRequest(BaseModel):
|
|||
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
||||
|
||||
|
||||
@app.post(
|
||||
@router.post(
|
||||
"/v1/embeddings",
|
||||
response_model=CreateEmbeddingResponse,
|
||||
)
|
||||
|
@ -257,7 +263,7 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
|
||||
|
||||
|
||||
@app.post(
|
||||
@router.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=CreateChatCompletionResponse,
|
||||
)
|
||||
|
@ -306,7 +312,7 @@ class ModelList(TypedDict):
|
|||
GetModelResponse = create_model_from_typeddict(ModelList)
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=GetModelResponse)
|
||||
@router.get("/v1/models", response_model=GetModelResponse)
|
||||
def get_models() -> ModelList:
|
||||
return {
|
||||
"object": "list",
|
||||
|
|
|
@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
|
|||
## Set up mock function
|
||||
def mock_eval(*args, **kwargs):
|
||||
return 0
|
||||
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
||||
return (llama_cpp.c_float * n_vocab)(
|
||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
|
|||
def test_llama_pickle():
|
||||
import pickle
|
||||
import tempfile
|
||||
|
||||
fp = tempfile.TemporaryFile()
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
pickle.dump(llama, fp)
|
||||
|
@ -101,6 +104,7 @@ def test_llama_pickle():
|
|||
|
||||
assert llama.detokenize(llama.tokenize(text)) == text
|
||||
|
||||
|
||||
def test_utf8(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
|
||||
|
@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
|
|||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
||||
return (llama_cpp.c_float * n_vocab)(
|
||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
@ -143,11 +149,13 @@ def test_utf8(monkeypatch):
|
|||
|
||||
def test_llama_server():
|
||||
from fastapi.testclient import TestClient
|
||||
from llama_cpp.server.app import app, init_llama, Settings
|
||||
s = Settings()
|
||||
s.model = MODEL
|
||||
s.vocab_only = True
|
||||
init_llama(s)
|
||||
from llama_cpp.server.app import create_app, Settings
|
||||
|
||||
settings = Settings(
|
||||
model=MODEL,
|
||||
vocab_only=True,
|
||||
)
|
||||
app = create_app(settings)
|
||||
client = TestClient(app)
|
||||
response = client.get("/v1/models")
|
||||
assert response.json() == {
|
||||
|
|
Loading…
Reference in a new issue