diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index 49a00b2..4360506 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs. """ import os import json +from threading import Lock from typing import List, Optional, Literal, Union, Iterator, Dict from typing_extensions import TypedDict import llama_cpp -from fastapi import FastAPI +from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict from sse_starlette.sse import EventSourceResponse @@ -59,6 +60,13 @@ llama = llama_cpp.Llama( n_ctx=settings.n_ctx, last_n_tokens_size=settings.last_n_tokens_size, ) +llama_lock = Lock() + + +def get_llama(): + with llama_lock: + yield llama + class CreateCompletionRequest(BaseModel): @@ -101,7 +109,7 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion) "/v1/completions", response_model=CreateCompletionResponse, ) -def create_completion(request: CreateCompletionRequest): +def create_completion(request: CreateCompletionRequest, llama: llama_cpp.Llama=Depends(get_llama)): if isinstance(request.prompt, list): request.prompt = "".join(request.prompt) @@ -146,7 +154,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding) "/v1/embeddings", response_model=CreateEmbeddingResponse, ) -def create_embedding(request: CreateEmbeddingRequest): +def create_embedding(request: CreateEmbeddingRequest, llama: llama_cpp.Llama=Depends(get_llama)): return llama.create_embedding(**request.dict(exclude={"model", "user"})) @@ -200,6 +208,7 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet ) def create_chat_completion( request: CreateChatCompletionRequest, + llama: llama_cpp.Llama=Depends(get_llama), ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]: completion_or_chunks = llama.create_chat_completion( **request.dict(