From 468377b0e210c205944edafb0325779e87347581 Mon Sep 17 00:00:00 2001 From: Lucas Doyle Date: Fri, 28 Apr 2023 22:43:37 -0700 Subject: [PATCH 1/3] llama_cpp server: app is now importable, still runnable as a module --- llama_cpp/server/__init__.py | 0 llama_cpp/server/__main__.py | 281 ++--------------------------------- llama_cpp/server/app.py | 266 +++++++++++++++++++++++++++++++++ 3 files changed, 279 insertions(+), 268 deletions(-) create mode 100644 llama_cpp/server/__init__.py create mode 100644 llama_cpp/server/app.py diff --git a/llama_cpp/server/__init__.py b/llama_cpp/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index af6cc38..dd4767f 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -5,283 +5,28 @@ To run this example: ```bash pip install fastapi uvicorn sse-starlette export MODEL=../models/7B/... -uvicorn fastapi_server_chat:app --reload +``` + +Then run: +``` +uvicorn llama_cpp.server.app:app --reload +``` + +or + +``` +python3 -m llama_cpp.server ``` Then visit http://localhost:8000/docs to see the interactive API docs. """ import os -import json -from threading import Lock -from typing import List, Optional, Literal, Union, Iterator, Dict -from typing_extensions import TypedDict - -import llama_cpp - -from fastapi import Depends, FastAPI -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict -from sse_starlette.sse import EventSourceResponse - - -class Settings(BaseSettings): - model: str - n_ctx: int = 2048 - n_batch: int = 512 - n_threads: int = max((os.cpu_count() or 2) // 2, 1) - f16_kv: bool = True - use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out... 
- use_mmap: bool = True - embedding: bool = True - last_n_tokens_size: int = 64 - logits_all: bool = False - cache: bool = False # WARNING: This is an experimental feature - - -app = FastAPI( - title="🦙 llama.cpp Python API", - version="0.0.1", -) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -settings = Settings() -llama = llama_cpp.Llama( - settings.model, - f16_kv=settings.f16_kv, - use_mlock=settings.use_mlock, - use_mmap=settings.use_mmap, - embedding=settings.embedding, - logits_all=settings.logits_all, - n_threads=settings.n_threads, - n_batch=settings.n_batch, - n_ctx=settings.n_ctx, - last_n_tokens_size=settings.last_n_tokens_size, -) -if settings.cache: - cache = llama_cpp.LlamaCache() - llama.set_cache(cache) -llama_lock = Lock() - - -def get_llama(): - with llama_lock: - yield llama - - -class CreateCompletionRequest(BaseModel): - prompt: Union[str, List[str]] - suffix: Optional[str] = Field(None) - max_tokens: int = 16 - temperature: float = 0.8 - top_p: float = 0.95 - echo: bool = False - stop: Optional[List[str]] = [] - stream: bool = False - - # ignored or currently unsupported - model: Optional[str] = Field(None) - n: Optional[int] = 1 - logprobs: Optional[int] = Field(None) - presence_penalty: Optional[float] = 0 - frequency_penalty: Optional[float] = 0 - best_of: Optional[int] = 1 - logit_bias: Optional[Dict[str, float]] = Field(None) - user: Optional[str] = Field(None) - - # llama.cpp specific parameters - top_k: int = 40 - repeat_penalty: float = 1.1 - - class Config: - schema_extra = { - "example": { - "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n", - "stop": ["\n", "###"], - } - } - - -CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion) - - -@app.post( - "/v1/completions", - response_model=CreateCompletionResponse, -) -def create_completion( - request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama) -): - if isinstance(request.prompt, list): - request.prompt = "".join(request.prompt) - - completion_or_chunks = llama( - **request.dict( - exclude={ - "model", - "n", - "frequency_penalty", - "presence_penalty", - "best_of", - "logit_bias", - "user", - } - ) - ) - if request.stream: - chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore - return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks) - completion: llama_cpp.Completion = completion_or_chunks # type: ignore - return completion - - -class CreateEmbeddingRequest(BaseModel): - model: Optional[str] - input: str - user: Optional[str] - - class Config: - schema_extra = { - "example": { - "input": "The food was delicious and the waiter...", - } - } - - -CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding) - - -@app.post( - "/v1/embeddings", - response_model=CreateEmbeddingResponse, -) -def create_embedding( - request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama) -): - return llama.create_embedding(**request.dict(exclude={"model", "user"})) - - -class ChatCompletionRequestMessage(BaseModel): - role: Union[Literal["system"], Literal["user"], Literal["assistant"]] - content: str - user: Optional[str] = None - - -class CreateChatCompletionRequest(BaseModel): - model: Optional[str] - messages: List[ChatCompletionRequestMessage] - temperature: float = 0.8 - top_p: float = 0.95 - stream: bool = False - stop: Optional[List[str]] = [] - max_tokens: int = 128 - - # 
ignored or currently unsupported - model: Optional[str] = Field(None) - n: Optional[int] = 1 - presence_penalty: Optional[float] = 0 - frequency_penalty: Optional[float] = 0 - logit_bias: Optional[Dict[str, float]] = Field(None) - user: Optional[str] = Field(None) - - # llama.cpp specific parameters - repeat_penalty: float = 1.1 - - class Config: - schema_extra = { - "example": { - "messages": [ - ChatCompletionRequestMessage( - role="system", content="You are a helpful assistant." - ), - ChatCompletionRequestMessage( - role="user", content="What is the capital of France?" - ), - ] - } - } - - -CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion) - - -@app.post( - "/v1/chat/completions", - response_model=CreateChatCompletionResponse, -) -def create_chat_completion( - request: CreateChatCompletionRequest, - llama: llama_cpp.Llama = Depends(get_llama), -) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]: - completion_or_chunks = llama.create_chat_completion( - **request.dict( - exclude={ - "model", - "n", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - } - ), - ) - - if request.stream: - - async def server_sent_events( - chat_chunks: Iterator[llama_cpp.ChatCompletionChunk], - ): - for chat_chunk in chat_chunks: - yield dict(data=json.dumps(chat_chunk)) - yield dict(data="[DONE]") - - chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore - - return EventSourceResponse( - server_sent_events(chunks), - ) - completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore - return completion - - -class ModelData(TypedDict): - id: str - object: Literal["model"] - owned_by: str - permissions: List[str] - - -class ModelList(TypedDict): - object: Literal["list"] - data: List[ModelData] - - -GetModelResponse = create_model_from_typeddict(ModelList) - - -@app.get("/v1/models", response_model=GetModelResponse) -def get_models() -> ModelList: - return { - "object": "list", - "data": [ - { - "id": llama.model_path, - "object": "model", - "owned_by": "me", - "permissions": [], - } - ], - } +import uvicorn +from llama_cpp.server.app import app if __name__ == "__main__": - import os - import uvicorn uvicorn.run( app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py new file mode 100644 index 0000000..d296e14 --- /dev/null +++ b/llama_cpp/server/app.py @@ -0,0 +1,266 @@ +import os +import json +from threading import Lock +from typing import List, Optional, Literal, Union, Iterator, Dict +from typing_extensions import TypedDict + +import llama_cpp + +from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict +from sse_starlette.sse import EventSourceResponse + + +class Settings(BaseSettings): + model: str = os.environ["MODEL"] + n_ctx: int = 2048 + n_batch: int = 512 + n_threads: int = max((os.cpu_count() or 2) // 2, 1) + f16_kv: bool = True + use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out... 
+ use_mmap: bool = True + embedding: bool = True + last_n_tokens_size: int = 64 + logits_all: bool = False + cache: bool = False # WARNING: This is an experimental feature + + +app = FastAPI( + title="🦙 llama.cpp Python API", + version="0.0.1", +) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +settings = Settings() +llama = llama_cpp.Llama( + settings.model, + f16_kv=settings.f16_kv, + use_mlock=settings.use_mlock, + use_mmap=settings.use_mmap, + embedding=settings.embedding, + logits_all=settings.logits_all, + n_threads=settings.n_threads, + n_batch=settings.n_batch, + n_ctx=settings.n_ctx, + last_n_tokens_size=settings.last_n_tokens_size, +) +if settings.cache: + cache = llama_cpp.LlamaCache() + llama.set_cache(cache) +llama_lock = Lock() + + +def get_llama(): + with llama_lock: + yield llama + + +class CreateCompletionRequest(BaseModel): + prompt: Union[str, List[str]] + suffix: Optional[str] = Field(None) + max_tokens: int = 16 + temperature: float = 0.8 + top_p: float = 0.95 + echo: bool = False + stop: Optional[List[str]] = [] + stream: bool = False + + # ignored or currently unsupported + model: Optional[str] = Field(None) + n: Optional[int] = 1 + logprobs: Optional[int] = Field(None) + presence_penalty: Optional[float] = 0 + frequency_penalty: Optional[float] = 0 + best_of: Optional[int] = 1 + logit_bias: Optional[Dict[str, float]] = Field(None) + user: Optional[str] = Field(None) + + # llama.cpp specific parameters + top_k: int = 40 + repeat_penalty: float = 1.1 + + class Config: + schema_extra = { + "example": { + "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n", + "stop": ["\n", "###"], + } + } + + +CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion) + + +@app.post( + "/v1/completions", + response_model=CreateCompletionResponse, +) +def create_completion( + request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama) +): + if isinstance(request.prompt, list): + request.prompt = "".join(request.prompt) + + completion_or_chunks = llama( + **request.dict( + exclude={ + "model", + "n", + "frequency_penalty", + "presence_penalty", + "best_of", + "logit_bias", + "user", + } + ) + ) + if request.stream: + chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore + return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks) + completion: llama_cpp.Completion = completion_or_chunks # type: ignore + return completion + + +class CreateEmbeddingRequest(BaseModel): + model: Optional[str] + input: str + user: Optional[str] + + class Config: + schema_extra = { + "example": { + "input": "The food was delicious and the waiter...", + } + } + + +CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding) + + +@app.post( + "/v1/embeddings", + response_model=CreateEmbeddingResponse, +) +def create_embedding( + request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama) +): + return llama.create_embedding(**request.dict(exclude={"model", "user"})) + + +class ChatCompletionRequestMessage(BaseModel): + role: Union[Literal["system"], Literal["user"], Literal["assistant"]] + content: str + user: Optional[str] = None + + +class CreateChatCompletionRequest(BaseModel): + model: Optional[str] + messages: List[ChatCompletionRequestMessage] + temperature: float = 0.8 + top_p: float = 0.95 + stream: bool = False + stop: Optional[List[str]] = [] + max_tokens: int = 128 + + # 
ignored or currently unsupported + model: Optional[str] = Field(None) + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0 + frequency_penalty: Optional[float] = 0 + logit_bias: Optional[Dict[str, float]] = Field(None) + user: Optional[str] = Field(None) + + # llama.cpp specific parameters + repeat_penalty: float = 1.1 + + class Config: + schema_extra = { + "example": { + "messages": [ + ChatCompletionRequestMessage( + role="system", content="You are a helpful assistant." + ), + ChatCompletionRequestMessage( + role="user", content="What is the capital of France?" + ), + ] + } + } + + +CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion) + + +@app.post( + "/v1/chat/completions", + response_model=CreateChatCompletionResponse, +) +def create_chat_completion( + request: CreateChatCompletionRequest, + llama: llama_cpp.Llama = Depends(get_llama), +) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]: + completion_or_chunks = llama.create_chat_completion( + **request.dict( + exclude={ + "model", + "n", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + } + ), + ) + + if request.stream: + + async def server_sent_events( + chat_chunks: Iterator[llama_cpp.ChatCompletionChunk], + ): + for chat_chunk in chat_chunks: + yield dict(data=json.dumps(chat_chunk)) + yield dict(data="[DONE]") + + chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore + + return EventSourceResponse( + server_sent_events(chunks), + ) + completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore + return completion + + +class ModelData(TypedDict): + id: str + object: Literal["model"] + owned_by: str + permissions: List[str] + + +class ModelList(TypedDict): + object: Literal["list"] + data: List[ModelData] + + +GetModelResponse = create_model_from_typeddict(ModelList) + + +@app.get("/v1/models", response_model=GetModelResponse) +def get_models() -> ModelList: + return { + "object": "list", + "data": [ + { + "id": llama.model_path, + "object": "model", + "owned_by": "me", + "permissions": [], + } + ], + } From 6d8db9d017b6b6b68bcff79cce5e770705ef016a Mon Sep 17 00:00:00 2001 From: Lucas Doyle Date: Fri, 28 Apr 2023 23:26:07 -0700 Subject: [PATCH 2/3] tests: simple test for server module --- llama_cpp/server/app.py | 2 + poetry.lock | 95 ++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + tests/test_llama.py | 21 +++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index d296e14..2c50fcb 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -24,6 +24,7 @@ class Settings(BaseSettings): last_n_tokens_size: int = 64 logits_all: bool = False cache: bool = False # WARNING: This is an experimental feature + vocab_only: bool = False app = FastAPI( @@ -49,6 +50,7 @@ llama = llama_cpp.Llama( n_batch=settings.n_batch, n_ctx=settings.n_ctx, last_n_tokens_size=settings.last_n_tokens_size, + vocab_only=settings.vocab_only, ) if settings.cache: cache = llama_cpp.LlamaCache() diff --git a/poetry.lock b/poetry.lock index 8a74d2f..a505168 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,25 @@ -# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. 
+ +[[package]] +name = "anyio" +version = "3.6.2" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "dev" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"}, + {file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"}, +] + +[package.dependencies] +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"] +trio = ["trio (>=0.16,<0.22)"] [[package]] name = "attrs" @@ -398,6 +419,64 @@ colorama = ">=0.4" [package.extras] async = ["aiofiles (>=0.7,<1.0)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "0.17.0" +description = "A minimal low-level HTTP client." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpcore-0.17.0-py3-none-any.whl", hash = "sha256:0fdfea45e94f0c9fd96eab9286077f9ff788dd186635ae61b312693e4d943599"}, + {file = "httpcore-0.17.0.tar.gz", hash = "sha256:cc045a3241afbf60ce056202301b4d8b6af08845e3294055eb26b09913ef903c"}, +] + +[package.dependencies] +anyio = ">=3.0,<5.0" +certifi = "*" +h11 = ">=0.13,<0.15" +sniffio = ">=1.0.0,<2.0.0" + +[package.extras] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + +[[package]] +name = "httpx" +version = "0.24.0" +description = "The next generation HTTP client." 
+category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpx-0.24.0-py3-none-any.whl", hash = "sha256:447556b50c1921c351ea54b4fe79d91b724ed2b027462ab9a329465d147d5a4e"}, + {file = "httpx-0.24.0.tar.gz", hash = "sha256:507d676fc3e26110d41df7d35ebd8b3b8585052450f4097401c9be59d928c63e"}, +] + +[package.dependencies] +certifi = "*" +httpcore = ">=0.15.0,<0.18.0" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] + [[package]] name = "idna" version = "3.4" @@ -1232,6 +1311,18 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -1367,4 +1458,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9" +content-hash = "aa15e57300668bd23c051b4cd87bec4c1a58dcccd2f2b4767579fea7f2c5fa41" diff --git a/pyproject.toml b/pyproject.toml index 798fcaf..362899b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ mkdocs = "^1.4.2" mkdocstrings = {extras = ["python"], version = "^0.20.0"} mkdocs-material = "^9.1.4" pytest = "^7.2.2" +httpx = "^0.24.0" [build-system] requires = [ diff --git a/tests/test_llama.py b/tests/test_llama.py index 4727d90..9110286 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -128,3 +128,24 @@ def test_utf8(monkeypatch): n = 0 # reset completion = llama.create_completion("", max_tokens=1) assert completion["choices"][0]["text"] == "" + + +def test_llama_server(): + from fastapi.testclient import TestClient + import os + os.environ["MODEL"] = MODEL + os.environ["VOCAB_ONLY"] = "true" + from llama_cpp.server.app import app + client = TestClient(app) + response = client.get("/v1/models") + assert response.json() == { + "object": "list", + "data": [ + { + "id": MODEL, + "object": "model", + "owned_by": "me", + "permissions": [], + } + ], + } From efe8e6f8795eb2f92db22b841a40ad41fb053fe1 Mon Sep 17 00:00:00 2001 From: Lucas Doyle Date: Fri, 28 Apr 2023 23:47:36 -0700 Subject: [PATCH 3/3] llama_cpp server: slight refactor to init_llama function Define an init_llama function that starts llama with supplied settings instead of just doing it in the global context of app.py This allows the test to be less brittle by not needing to mess with os.environ, then importing the app --- llama_cpp/server/__main__.py | 3 ++- llama_cpp/server/app.py | 45 +++++++++++++++++++----------------- tests/test_llama.py | 9 ++++---- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index dd4767f..f57d68c 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -24,9 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs. 
import os import uvicorn -from llama_cpp.server.app import app +from llama_cpp.server.app import app, init_llama if __name__ == "__main__": + init_llama() uvicorn.run( app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 2c50fcb..92b023c 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -13,7 +13,7 @@ from sse_starlette.sse import EventSourceResponse class Settings(BaseSettings): - model: str = os.environ["MODEL"] + model: str = os.environ.get("MODEL", "null") n_ctx: int = 2048 n_batch: int = 512 n_threads: int = max((os.cpu_count() or 2) // 2, 1) @@ -38,31 +38,34 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) -settings = Settings() -llama = llama_cpp.Llama( - settings.model, - f16_kv=settings.f16_kv, - use_mlock=settings.use_mlock, - use_mmap=settings.use_mmap, - embedding=settings.embedding, - logits_all=settings.logits_all, - n_threads=settings.n_threads, - n_batch=settings.n_batch, - n_ctx=settings.n_ctx, - last_n_tokens_size=settings.last_n_tokens_size, - vocab_only=settings.vocab_only, -) -if settings.cache: - cache = llama_cpp.LlamaCache() - llama.set_cache(cache) + +llama: llama_cpp.Llama = None +def init_llama(settings: Settings = None): + if settings is None: + settings = Settings() + global llama + llama = llama_cpp.Llama( + settings.model, + f16_kv=settings.f16_kv, + use_mlock=settings.use_mlock, + use_mmap=settings.use_mmap, + embedding=settings.embedding, + logits_all=settings.logits_all, + n_threads=settings.n_threads, + n_batch=settings.n_batch, + n_ctx=settings.n_ctx, + last_n_tokens_size=settings.last_n_tokens_size, + vocab_only=settings.vocab_only, + ) + if settings.cache: + cache = llama_cpp.LlamaCache() + llama.set_cache(cache) + llama_lock = Lock() - - def get_llama(): with llama_lock: yield llama - class CreateCompletionRequest(BaseModel): prompt: Union[str, List[str]] suffix: Optional[str] = Field(None) diff --git a/tests/test_llama.py b/tests/test_llama.py index 9110286..c3f69cc 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -132,10 +132,11 @@ def test_utf8(monkeypatch): def test_llama_server(): from fastapi.testclient import TestClient - import os - os.environ["MODEL"] = MODEL - os.environ["VOCAB_ONLY"] = "true" - from llama_cpp.server.app import app + from llama_cpp.server.app import app, init_llama, Settings + s = Settings() + s.model = MODEL + s.vocab_only = True + init_llama(s) client = TestClient(app) response = client.get("/v1/models") assert response.json() == {
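
Usage sketch (assuming the module layout introduced by these patches; the model path below is a placeholder, not taken from the series): with the `init_llama()` refactor from PATCH 3/3, a caller can build a `Settings` object, override fields on it, and initialize the model explicitly before serving `app` — the same pattern `test_llama_server` uses with `vocab_only`, here applied to a normal server start instead of a test.

```python
# Minimal sketch: start the server programmatically with custom settings.
# Assumes the module layout from these patches; the model path is a placeholder.
import os

import uvicorn

from llama_cpp.server.app import Settings, app, init_llama

settings = Settings()
settings.model = "./models/7B/ggml-model.bin"  # placeholder path; point at a real ggml model
settings.vocab_only = False                    # set True to load only the vocabulary (as the test does)

init_llama(settings)  # loads the model once, before the first request arrives

uvicorn.run(
    app,
    host=os.getenv("HOST", "localhost"),
    port=int(os.getenv("PORT", 8000)),
)
```

This mirrors `llama_cpp/server/__main__.py` after PATCH 3/3, except that the settings are constructed and overridden in code rather than read only from environment variables.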