add generate command

Bruce MacDonald 2023-06-26 13:00:40 -04:00
parent fcf15df1ef
commit 3ca8f72327


@@ -31,19 +31,19 @@ def unload(model):
     return None


-def generate(model, prompt):
+def query(model, prompt):
     # auto load
     error = load(model)
     if error is not None:
         return error
-    stream = llms[model](
+    generated = llms[model](
         str(prompt), # TODO: optimize prompt based on model
         max_tokens=4096,
         stop=["Q:", "\n"],
         echo=True,
         stream=True,
     )
-    for output in stream:
+    for output in generated:
         yield json.dumps(output)
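
Note: the renamed query() is a generator that yields one JSON-encoded completion chunk at a time. A minimal in-process sketch of consuming it, mirroring what the new CLI command further down does (the chunk shape {"choices": [{"text": ...}]} is the llama-cpp-python streaming format this file assumes):

    import json

    # query() yields JSON strings; decode each chunk and stitch the text together
    text = ""
    for chunk in query("vicuna-7b-v1.3.ggmlv3.q8_0", "Q: What is a llama?"):
        text += json.loads(chunk)["choices"][0]["text"]
    print(text)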
@@ -91,7 +91,7 @@ def generate_route_handler():
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
     return Response(
-        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
+        stream_with_context(query(model, prompt)), mimetype="text/event-stream"
     )
@@ -117,5 +117,19 @@ def serve(port, debug):
     app.run(host="0.0.0.0", port=port, debug=debug)


+@cli.command()
+@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
+@click.option("--prompt", default="", help="The prompt for the model")
+def generate(model, prompt):
+    if prompt == "":
+        prompt = input("Prompt: ")
+    output = ""
+    for generated in query(model, prompt):
+        generated_json = json.loads(generated)
+        text = generated_json["choices"][0]["text"]
+        output += text
+        print(f"\r{output}", end="", flush=True)
+
+
 if __name__ == "__main__":
     cli()
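
Usage sketch for the new command (the generate subcommand and the --model/--prompt flags come from the diff above; the module name is hypothetical, since the diff does not show which file defines cli()):

    # hypothetical entry-point file name
    python server.py generate --model vicuna-7b-v1.3.ggmlv3.q8_0 --prompt "Q: What is a llama?"

When --prompt is omitted, the command falls back to input("Prompt: "), then streams the completion to the terminal, redrawing the accumulated output in place via the \r carriage return.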