From 7454900733c8ab720b0b018898813150ea48edee Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 27 Jun 2023 17:26:31 -0400
Subject: [PATCH] add cors

---
 ollama/cmd/cli.py    | 26 +++++++++++------
 ollama/cmd/server.py | 66 +++++++++++++++++++++++++++++---------------
 requirements.txt     |  6 ++--
 3 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/ollama/cmd/cli.py b/ollama/cmd/cli.py
index 7a0c03c1..65a3ad44 100644
--- a/ollama/cmd/cli.py
+++ b/ollama/cmd/cli.py
@@ -8,24 +8,28 @@ from ollama.cmd import server
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument('--models-home', default=Path.home() / '.ollama' / 'models')
+    parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
 
     subparsers = parser.add_subparsers()
 
-    server.set_parser(subparsers.add_parser('serve'))
+    server.set_parser(subparsers.add_parser("serve"))
 
-    list_parser = subparsers.add_parser('list')
+    list_parser = subparsers.add_parser("list")
     list_parser.set_defaults(fn=list)
 
-    generate_parser = subparsers.add_parser('generate')
-    generate_parser.add_argument('model')
-    generate_parser.add_argument('prompt')
+    generate_parser = subparsers.add_parser("generate")
+    generate_parser.add_argument("model")
+    generate_parser.add_argument("prompt")
     generate_parser.set_defaults(fn=generate)
 
+    add_parser = subparsers.add_parser("add")
+    add_parser.add_argument("model_file")
+    add_parser.set_defaults(fn=add)
+
     args = parser.parse_args()
     args = vars(args)
 
-    fn = args.pop('fn')
+    fn = args.pop("fn")
     fn(**args)
 
 
@@ -38,6 +42,10 @@ def generate(*args, **kwargs):
     for output in engine.generate(*args, **kwargs):
         output = json.loads(output)
 
-        choices = output.get('choices', [])
+        choices = output.get("choices", [])
         if len(choices) > 0:
-            print(choices[0].get('text', ''), end='')
+            print(choices[0].get("text", ""), end="")
+
+
+def add(*args, **kwargs):
+    model.add(*args, **kwargs)
diff --git a/ollama/cmd/server.py b/ollama/cmd/server.py
index 9e2180cd..5a99e937 100644
--- a/ollama/cmd/server.py
+++ b/ollama/cmd/server.py
@@ -1,39 +1,59 @@
 from aiohttp import web
+import aiohttp_cors
 
 from ollama import engine
 
 def set_parser(parser):
-    parser.add_argument('--host', default='127.0.0.1')
-    parser.add_argument('--port', default=7734)
+    parser.add_argument("--host", default="127.0.0.1")
+    parser.add_argument("--port", default=7734)
     parser.set_defaults(fn=serve)
 
 
-def serve(models_home='.', *args, **kwargs):
+def serve(models_home=".", *args, **kwargs):
     app = web.Application()
 
-    app.add_routes([
-        web.post('/load', load),
-        web.post('/unload', unload),
-        web.post('/generate', generate),
-    ])
-    app.update({
-        'llms': {},
-        'models_home': models_home,
-    })
+    cors = aiohttp_cors.setup(
+        app,
+        defaults={
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True,
+                expose_headers="*",
+                allow_headers="*",
+            )
+        },
+    )
+
+    app.add_routes(
+        [
+            web.post("/load", load),
+            web.post("/unload", unload),
+            web.post("/generate", generate),
+        ]
+    )
+
+    for route in app.router.routes():
+        cors.add(route)
+
+    app.update(
+        {
+            "llms": {},
+            "models_home": models_home,
+        }
+    )
 
     web.run_app(app, **kwargs)
 
 
 async def load(request):
     body = await request.json()
 
-    model = body.get('model')
+    model = body.get("model")
    if not model:
         raise web.HTTPBadRequest()
 
     kwargs = {
-        'llms': request.app.get('llms'),
-        'models_home': request.app.get('models_home'),
+        "llms": request.app.get("llms"),
+        "models_home": request.app.get("models_home"),
     }
 
     engine.load(model, **kwargs)
@@ -42,21 +62,21 @@ async def unload(request):
     body = await request.json()
 
-    model = body.get('model')
+    model = body.get("model")
     if not model:
         raise web.HTTPBadRequest()
 
-    engine.unload(model, llms=request.app.get('llms'))
+    engine.unload(model, llms=request.app.get("llms"))
 
     return web.Response()
 
 
 async def generate(request):
     body = await request.json()
 
-    model = body.get('model')
+    model = body.get("model")
     if not model:
         raise web.HTTPBadRequest()
 
-    prompt = body.get('prompt')
+    prompt = body.get("prompt")
     if not prompt:
         raise web.HTTPBadRequest()
@@ -64,12 +84,12 @@ async def generate(request):
     await response.prepare(request)
 
     kwargs = {
-        'llms': request.app.get('llms'),
-        'models_home': request.app.get('models_home'),
+        "llms": request.app.get("llms"),
+        "models_home": request.app.get("models_home"),
     }
 
     for output in engine.generate(model, prompt, **kwargs):
-        await response.write(output.encode('utf-8'))
-        await response.write(b'\n')
+        await response.write(output.encode("utf-8"))
+        await response.write(b"\n")
 
     return response
diff --git a/requirements.txt b/requirements.txt
index cf16ee13..eee30486 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,5 @@
-click==8.1.3
-Flask==2.3.2
-Flask_Cors==3.0.10
+aiohttp==3.8.4
+aiohttp_cors==0.7.0
 llama_cpp_python==0.1.65
 pyinstaller==5.13.0
 setuptools==65.6.3
-tqdm==4.65.0
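
Note on the CORS defaults above, plus a sketch that is not part of the patch: with a "*" default, aiohttp_cors answers preflight requests from any origin, with credentials allowed and all headers exposed, which is what lets a browser-based client call the local server. The example below shows how a client might consume the streaming /generate endpoint. It is a minimal sketch under stated assumptions: the server runs on the default 127.0.0.1:7734, and each streamed chunk is one JSON object per line (the server delimits chunks with b"\n", and the CLI reads choices[0]["text"] from each). The model name is a placeholder.

import asyncio
import json

import aiohttp  # already a dependency per requirements.txt in this patch


async def main(model: str, prompt: str, host: str = "http://127.0.0.1:7734"):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{host}/generate", json={"model": model, "prompt": prompt}
        ) as response:
            response.raise_for_status()
            # The server writes one JSON object per line, so iterate the
            # body line by line and decode each chunk as it arrives.
            async for line in response.content:
                if not line.strip():
                    continue
                output = json.loads(line)
                choices = output.get("choices", [])
                if choices:
                    print(choices[0].get("text", ""), end="", flush=True)


asyncio.run(main("some-model", "Why is the sky blue?"))  # placeholder model name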