add generate command

Bruce MacDonald 2023-06-26 13:00:40 -04:00
parent fcf15df1ef
commit 3ca8f72327

@@ -31,19 +31,19 @@ def unload(model):
     return None


-def generate(model, prompt):
+def query(model, prompt):
     # auto load
     error = load(model)
     if error is not None:
         return error

-    stream = llms[model](
+    generated = llms[model](
         str(prompt),  # TODO: optimize prompt based on model
         max_tokens=4096,
         stop=["Q:", "\n"],
         echo=True,
         stream=True,
     )

-    for output in stream:
+    for output in generated:
         yield json.dumps(output)
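Each value yielded by query() is a JSON-encoded llama-cpp-python streaming chunk whose text lives at choices[0]["text"]; the new CLI command below decodes it the same way. A minimal consumer sketch, with an illustrative prompt:

    import json

    for chunk in query("vicuna-7b-v1.3.ggmlv3.q8_0", "Q: What is a llama? A:"):
        data = json.loads(chunk)  # query() yields one JSON string per streamed chunk
        print(data["choices"][0]["text"], end="", flush=True)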
@@ -91,7 +91,7 @@ def generate_route_handler():
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
     return Response(
-        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
+        stream_with_context(query(model, prompt)), mimetype="text/event-stream"
     )
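A hedged client sketch for the streaming route: the URL path, port, and request shape are defined outside this hunk, so the /generate path, port 5000, and the JSON body below are all assumptions.

    import requests

    resp = requests.post(
        "http://localhost:5000/generate",  # assumed path and default Flask port
        json={"model": "vicuna-7b-v1.3.ggmlv3.q8_0", "prompt": "Q: Hi A:"},
        stream=True,
    )
    for line in resp.iter_lines():  # one JSON chunk per event-stream line
        if line:
            print(line.decode())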
@@ -117,5 +117,19 @@ def serve(port, debug):
     app.run(host="0.0.0.0", port=port, debug=debug)


+@cli.command()
+@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
+@click.option("--prompt", default="", help="The prompt for the model")
+def generate(model, prompt):
+    if prompt == "":
+        prompt = input("Prompt: ")
+    output = ""
+    for generated in query(model, prompt):
+        generated_json = json.loads(generated)
+        text = generated_json["choices"][0]["text"]
+        output += text
+        print(f"\r{output}", end="", flush=True)
+
+
 if __name__ == "__main__":
     cli()
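Assuming the module is run directly (the script name is not visible in this diff), the new command can be invoked as:

    python server.py generate --model vicuna-7b-v1.3.ggmlv3.q8_0 --prompt "Q: What is a llama? A:"

Printing with a leading \r redraws the accumulated output in place as each chunk arrives; this stays on a single line because the stop list includes "\n".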