load and unload model endpoints

This commit is contained in:
Bruce MacDonald 2023-06-23 14:47:57 -04:00
parent b97b9504e6
commit ebec1c61db

View file

@ -11,6 +11,37 @@ CORS(app) # enable CORS for all routes
llms = {} llms = {}
@app.route("/load", methods=["POST"])
def load():
data = request.get_json()
model = data.get("model")
if not model:
return Response("Model is required", status=400)
if not os.path.exists(f"../models/{model}.bin"):
return {"error": "The model does not exist."}, 400
if model not in llms:
llms[model] = Llama(model_path=f"../models/{model}.bin")
return Response(status=204)
@app.route("/unload", methods=["POST"])
def unload():
data = request.get_json()
model = data.get("model")
if not model:
return Response("Model is required", status=400)
if not os.path.exists(f"../models/{model}.bin"):
return {"error": "The model does not exist."}, 400
llms.pop(model, None)
return Response(status=204)
@app.route("/generate", methods=["POST"]) @app.route("/generate", methods=["POST"])
def generate(): def generate():
data = request.get_json() data = request.get_json()
@ -22,9 +53,10 @@ def generate():
if not prompt: if not prompt:
return Response("Prompt is required", status=400) return Response("Prompt is required", status=400)
if not os.path.exists(f"../models/{model}.bin"): if not os.path.exists(f"../models/{model}.bin"):
return {"error": "The model file does not exist."}, 400 return {"error": "The model does not exist."}, 400
if model not in llms: if model not in llms:
# auto load
llms[model] = Llama(model_path=f"../models/{model}.bin") llms[model] = Llama(model_path=f"../models/{model}.bin")
def stream_response(): def stream_response():