spinner on generate

This commit is contained in:
Bruce MacDonald 2023-06-29 16:34:50 -04:00
parent 55ab5e60db
commit faa1ce3195
3 changed files with 38 additions and 1 deletions

View file

@ -2,6 +2,7 @@ import os
import sys import sys
from pathlib import Path from pathlib import Path
from argparse import ArgumentParser from argparse import ArgumentParser
from yaspin import yaspin
from ollama import model, engine from ollama import model, engine
from ollama.cmd import server from ollama.cmd import server
@ -75,9 +76,16 @@ def generate(*args, **kwargs):
def generate_oneshot(*args, **kwargs): def generate_oneshot(*args, **kwargs):
print(flush=True) print(flush=True)
spinner = yaspin()
spinner.start()
spinner_running = True
for output in engine.generate(*args, **kwargs): for output in engine.generate(*args, **kwargs):
choices = output.get("choices", []) choices = output.get("choices", [])
if len(choices) > 0: if len(choices) > 0:
if spinner_running:
spinner.stop()
spinner_running = False
print("\r", end="") # move cursor back to beginning of line again
print(choices[0].get("text", ""), end="", flush=True) print(choices[0].get("text", ""), end="", flush=True)
# end with a new line # end with a new line

30
poetry.lock generated
View file

@ -622,6 +622,20 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "termcolor"
version = "2.3.0"
description = "ANSI color formatting for output in terminal"
optional = false
python-versions = ">=3.7"
files = [
{file = "termcolor-2.3.0-py3-none-any.whl", hash = "sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475"},
{file = "termcolor-2.3.0.tar.gz", hash = "sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a"},
]
[package.extras]
tests = ["pytest", "pytest-cov"]
[[package]] [[package]]
name = "tqdm" name = "tqdm"
version = "4.65.0" version = "4.65.0"
@ -773,7 +787,21 @@ files = [
idna = ">=2.0" idna = ">=2.0"
multidict = ">=4.0" multidict = ">=4.0"
[[package]]
name = "yaspin"
version = "2.3.0"
description = "Yet Another Terminal Spinner"
optional = false
python-versions = ">=3.7.2,<4.0.0"
files = [
{file = "yaspin-2.3.0-py3-none-any.whl", hash = "sha256:17b5548479b3d5b30adec7a87ffcdcddb403d14a2bb86fbcee97f37951e13427"},
{file = "yaspin-2.3.0.tar.gz", hash = "sha256:547afd1a9700ac3a29a9f5591c70343bef186ed5dfb5e545a9bb0c77e561a1c9"},
]
[package.dependencies]
termcolor = ">=2.2,<3.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "45d11efe54e87646d1a469782d628c6687815328bc52040e9fc622a508df7684" content-hash = "76a09c53830f5066a8fbdfb93b90c8847e3e60dbeb8989fc1c167a8e3a41d90d"

View file

@ -16,6 +16,7 @@ jinja2 = "^3.1.2"
requests = "^2.31.0" requests = "^2.31.0"
tqdm = "^4.65.0" tqdm = "^4.65.0"
validators = "^0.20.0" validators = "^0.20.0"
yaspin = "^2.3.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]