Fix threading bug. Closes #62

Andrei Betlen 2023-04-12 19:07:53 -04:00
parent 005c78d26c
commit 19598ac4e8


@@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 """
 import os
 import json
+from threading import Lock
 from typing import List, Optional, Literal, Union, Iterator, Dict
 from typing_extensions import TypedDict
 import llama_cpp
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -59,6 +60,13 @@ llama = llama_cpp.Llama(
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
 class CreateCompletionRequest(BaseModel):
@@ -101,7 +109,7 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(request: CreateCompletionRequest):
+def create_completion(request: CreateCompletionRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     if isinstance(request.prompt, list):
         request.prompt = "".join(request.prompt)
@@ -146,7 +154,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(request: CreateEmbeddingRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     return llama.create_embedding(**request.dict(exclude={"model", "user"}))
@@ -200,6 +208,7 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 )
 def create_chat_completion(
     request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama=Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
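
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of what the change does: a module-level threading.Lock is wrapped in a FastAPI generator dependency, so each request that depends on it acquires the lock before the handler runs and releases it afterward, serializing access to a single shared object that is not thread-safe. The SharedModel class and its generate() method below are placeholders standing in for the llama_cpp.Llama instance, not part of any real API.

from threading import Lock

from fastapi import Depends, FastAPI

app = FastAPI()


class SharedModel:
    """Placeholder for a non-thread-safe resource such as a loaded model."""

    def generate(self, prompt: str) -> str:
        return f"echo: {prompt}"


shared_model = SharedModel()
model_lock = Lock()


def get_model():
    # FastAPI treats a dependency that yields as a context manager:
    # the lock is acquired before the request handler runs and released
    # when the dependency's cleanup code executes after the handler.
    with model_lock:
        yield shared_model


@app.get("/generate")
def generate(prompt: str, model: SharedModel = Depends(get_model)):
    # Only one request at a time reaches this point with the model.
    return {"text": model.generate(prompt)}

With the dependency in place, the handlers themselves stay unchanged apart from the extra parameter, which is what the hunks above show for /v1/completions, /v1/embeddings and /v1/chat/completions.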