Fix threading bug. Closes #62

Andrei Betlen 2023-04-12 19:07:53 -04:00
parent 005c78d26c
commit 19598ac4e8


@@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 """
 import os
 import json
+from threading import Lock
 from typing import List, Optional, Literal, Union, Iterator, Dict
 from typing_extensions import TypedDict
 
 import llama_cpp
 
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -59,6 +60,13 @@ llama = llama_cpp.Llama(
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
+
 
 class CreateCompletionRequest(BaseModel):
@@ -101,7 +109,7 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(request: CreateCompletionRequest):
+def create_completion(request: CreateCompletionRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     if isinstance(request.prompt, list):
         request.prompt = "".join(request.prompt)
@@ -146,7 +154,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(request: CreateEmbeddingRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     return llama.create_embedding(**request.dict(exclude={"model", "user"}))
@@ -200,6 +208,7 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 )
 def create_chat_completion(
     request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama=Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
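
The change serializes access to the single shared llama_cpp.Llama instance: a module-level threading.Lock is wrapped in a generator dependency (get_llama), and each endpoint receives the model through Depends(get_llama), so the lock is held from the time a handler starts until the dependency's teardown runs. Below is a minimal, self-contained sketch of the same wiring; the FakeModel class, /generate route, and file name sketch.py are hypothetical stand-ins for the real Llama instance and server endpoints, so this illustrates the pattern rather than the project's actual module.

from threading import Lock

from fastapi import Depends, FastAPI
from pydantic import BaseModel

app = FastAPI()


class FakeModel:
    """Hypothetical stand-in for llama_cpp.Llama: assumed not thread-safe."""

    def generate(self, prompt: str) -> str:
        return prompt.upper()


model = FakeModel()
model_lock = Lock()


def get_model():
    # Generator dependency: the lock is acquired before the handler runs
    # and released when the dependency is torn down after the response.
    with model_lock:
        yield model


class GenerateRequest(BaseModel):
    prompt: str


@app.post("/generate")
def generate(request: GenerateRequest, model: FakeModel = Depends(get_model)):
    # Only one request at a time reaches this point, so concurrent calls
    # into the shared, non-thread-safe model are prevented.
    return {"text": model.generate(request.prompt)}

Saved as sketch.py and started with "uvicorn sketch:app", a second concurrent POST to /generate blocks until the first handler returns and the lock is released, which is the behavior the commit relies on for the shared Llama object.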