add generate command
This commit is contained in:
parent
fcf15df1ef
commit
3ca8f72327
1 changed file with 18 additions and 4 deletions
22
proto.py
22
proto.py
|
@ -31,19 +31,19 @@ def unload(model):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def query(model, prompt):
    """Yield JSON-encoded completion chunks produced by *model* for *prompt*.

    The model is loaded on demand before the completion is requested, so
    callers do not have to load it explicitly.
    """
    # Auto load: make sure the model is resident before querying it.
    err = load(model)
    if err is not None:
        # NOTE(review): a `return value` inside a generator only sets
        # StopIteration.value — a caller iterating this generator never
        # observes the error object. Confirm this is the intended way to
        # surface load failures.
        return err

    chunks = llms[model](
        str(prompt),  # TODO: optimize prompt based on model
        max_tokens=4096,
        stop=["Q:", "\n"],
        echo=True,
        stream=True,
    )
    for chunk in chunks:
        yield json.dumps(chunk)
|
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ def generate_route_handler():
|
||||||
if not os.path.exists(f"./models/{model}.bin"):
|
if not os.path.exists(f"./models/{model}.bin"):
|
||||||
return {"error": "The model does not exist."}, 400
|
return {"error": "The model does not exist."}, 400
|
||||||
return Response(
|
return Response(
|
||||||
stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
|
stream_with_context(query(model, prompt)), mimetype="text/event-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,5 +117,19 @@ def serve(port, debug):
|
||||||
app.run(host="0.0.0.0", port=port, debug=debug)
|
app.run(host="0.0.0.0", port=port, debug=debug)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
@click.option("--prompt", default="", help="The prompt for the model")
def generate(model, prompt):
    """Stream a completion for *prompt* from *model* to the terminal.

    Asks for a prompt interactively when --prompt is not supplied. The
    accumulated output is re-printed in place (via carriage return) as
    each streamed chunk arrives.
    """
    if prompt == "":
        prompt = input("Prompt: ")

    accumulated = ""
    for raw_chunk in query(model, prompt):
        chunk = json.loads(raw_chunk)
        accumulated += chunk["choices"][0]["text"]
        print(f"\r{accumulated}", end="", flush=True)
|
|
||||||
|
|
||||||
# Script entry point: dispatch to the click command group.
if __name__ == "__main__":
    cli()
|
|
Loading…
Reference in a new issue