diff --git a/server/server.py b/server/server.py
index 813836f3..1d5b4c54 100644
--- a/server/server.py
+++ b/server/server.py
@@ -11,6 +11,37 @@
 CORS(app)  # enable CORS for all routes
 llms = {}
 
+@app.route("/load", methods=["POST"])
+def load():
+    data = request.get_json()
+    model = data.get("model")
+
+    if not model:
+        return Response("Model is required", status=400)
+    if not os.path.exists(f"../models/{model}.bin"):
+        return {"error": "The model does not exist."}, 400
+
+    if model not in llms:
+        llms[model] = Llama(model_path=f"../models/{model}.bin")
+
+    return Response(status=204)
+
+
+@app.route("/unload", methods=["POST"])
+def unload():
+    data = request.get_json()
+    model = data.get("model")
+
+    if not model:
+        return Response("Model is required", status=400)
+    if not os.path.exists(f"../models/{model}.bin"):
+        return {"error": "The model does not exist."}, 400
+
+    llms.pop(model, None)  # no-op if the model was never loaded
+
+    return Response(status=204)
+
+
 @app.route("/generate", methods=["POST"])
 def generate():
     data = request.get_json()
@@ -22,9 +53,10 @@ def generate():
     if not prompt:
         return Response("Prompt is required", status=400)
     if not os.path.exists(f"../models/{model}.bin"):
-        return {"error": "The model file does not exist."}, 400
+        return {"error": "The model does not exist."}, 400
 
     if model not in llms:
+        # auto-load the model on first use
        llms[model] = Llama(model_path=f"../models/{model}.bin")
 
     def stream_response():
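
For reference, here is a minimal client sketch exercising the new endpoints. The host/port (Flask's default `http://localhost:5000`) and the model name `llama-2-7b` are assumptions, not part of the diff; the request shape follows the `model`/`prompt` JSON fields the routes read.

```python
# Client sketch for the new /load and /unload endpoints.
# Assumptions (not in the diff): server at Flask's default http://localhost:5000,
# and a file ../models/llama-2-7b.bin present server-side (placeholder name).
import requests

BASE_URL = "http://localhost:5000"  # assumed host/port
MODEL = "llama-2-7b"                # placeholder model name

# Pre-load the model so the first /generate request doesn't pay the load cost.
resp = requests.post(f"{BASE_URL}/load", json={"model": MODEL})
assert resp.status_code == 204, resp.text

# Generate against the cached model; the route streams its response.
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"model": MODEL, "prompt": "Say hello."},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="", flush=True)

# Evict the model from the in-memory cache to free its memory.
resp = requests.post(f"{BASE_URL}/unload", json={"model": MODEL})
assert resp.status_code == 204, resp.text
```

Note the division of labor: since `/generate` auto-loads missing models, `/load` is purely a first-request latency optimization, while `/unload` is what adds new capability by evicting an entry from `llms` and releasing its memory.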