diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 60d5194..89e7cb5 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -7,7 +7,9 @@ from ctypes import * # type: ignore from enum import Enum from itertools import islice from typing import ( + Any, Callable, + Dict, Generic, List, Optional, @@ -1399,15 +1401,15 @@ class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {"space": SPACE_RULE} - self._defs = {} + self._defs: Dict[str, Any] = {} - def _format_literal(self, literal): - escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + def _format_literal(self, literal: str): + escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub( lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) ) return f'"{escaped}"' - def _add_rule(self, name, rule): + def _add_rule(self, name: str, rule: str): esc_name = INVALID_RULE_CHARS_RE.sub("-", name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name @@ -1419,8 +1421,9 @@ class SchemaConverter: self._rules[key] = rule return key - def visit(self, schema, name): - schema_type = schema.get("type") + def visit(self, schema: Dict[str, Any], name: str) -> str: + schema_type: Optional[str] = schema.get("type") # type: ignore + assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" rule_name = name or "root" if "$defs" in schema: