diff --git a/ollama/prompt.py b/ollama/prompt.py index 5e329e3e..84a16f33 100644 --- a/ollama/prompt.py +++ b/ollama/prompt.py @@ -1,19 +1,9 @@ -from os import path -from difflib import SequenceMatcher +from difflib import get_close_matches from jinja2 import Environment, PackageLoader def template(name, prompt): - best_ratio = 0 - best_template = '' - environment = Environment(loader=PackageLoader(__name__, 'templates')) - for template in environment.list_templates(): - base, _ = path.splitext(template) - ratio = SequenceMatcher(None, path.basename(name).lower(), base).ratio() - if ratio > best_ratio: - best_ratio = ratio - best_template = template - - template = environment.get_template(best_template) + best_templates = get_close_matches(name, environment.list_templates(), n=1, cutoff=0) + template = environment.get_template(best_templates.pop()) return template.render(prompt=prompt)