diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index bad75df..30ae3b5 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import os
 import sys
-import abc
 import uuid
 import time
 import multiprocessing
@@ -15,7 +14,6 @@ from typing import (
     Iterator,
     Deque,
     Callable,
-    Any,
 )
 from collections import deque
 
@@ -31,6 +29,10 @@ from .llama_cache import (
     LlamaDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
 )
+from .llama_tokenizer import (
+    BaseLlamaTokenizer,
+    LlamaTokenizer
+)
 
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
@@ -1747,69 +1749,6 @@ class Llama:
         return longest_prefix
 
 
-class BaseLlamaTokenizer(abc.ABC):
-    @abc.abstractmethod
-    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
-        raise NotImplementedError
-
-
-class LlamaTokenizer(BaseLlamaTokenizer):
-    def __init__(self, llama: Llama):
-        self.llama = llama
-        self._model = llama._model  # type: ignore
-
-    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
-        return self._model.tokenize(text, add_bos=add_bos, special=special)
-
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
-        if prev_tokens is not None:
-            return self._model.detokenize(tokens[len(prev_tokens):])
-        else:
-            return self._model.detokenize(tokens)
-
-    def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
-        return self.tokenize(
-            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
-        )
-
-    def decode(self, tokens: List[int]) -> str:
-        return self.detokenize(tokens).decode("utf-8", errors="ignore")
-
-    @classmethod
-    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
-        return cls(Llama(model_path=path, vocab_only=True))
-
-
-class LlamaHFTokenizer(BaseLlamaTokenizer):
-    def __init__(self, hf_tokenizer: Any):
-        self.hf_tokenizer = hf_tokenizer
-
-    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
-        return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)
-
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
-        if prev_tokens is not None:
-            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
-            prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
-            return text[len(prev_text):]
-        else:
-            return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
-        try:
-            from transformers import AutoTokenizer
-        except ImportError:
-            raise ImportError(
-                "The `transformers` library is required to use the `HFTokenizer`."
-                "You can install it with `pip install transformers`."
-            )
-        hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
-        return cls(hf_tokenizer)
 
 
 class LlamaState:
diff --git a/llama_cpp/llama_tokenizer.py b/llama_cpp/llama_tokenizer.py
new file mode 100644
index 0000000..0ad3c3a
--- /dev/null
+++ b/llama_cpp/llama_tokenizer.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import abc
+from typing import (
+    List,
+    Optional,
+    Any,
+)
+
+import llama_cpp
+from llama_cpp.llama_types import List
+
+
+class BaseLlamaTokenizer(abc.ABC):
+    @abc.abstractmethod
+    def tokenize(
+        self, text: bytes, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
+        raise NotImplementedError
+
+
+class LlamaTokenizer(BaseLlamaTokenizer):
+    def __init__(self, llama: llama_cpp.Llama):
+        self.llama = llama
+        self._model = llama._model  # type: ignore
+
+    def tokenize(
+        self, text: bytes, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        return self._model.tokenize(text, add_bos=add_bos, special=special)
+
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
+        if prev_tokens is not None:
+            return self._model.detokenize(tokens[len(prev_tokens) :])
+        else:
+            return self._model.detokenize(tokens)
+
+    def encode(
+        self, text: str, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        return self.tokenize(
+            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+
+    def decode(self, tokens: List[int]) -> str:
+        return self.detokenize(tokens).decode("utf-8", errors="ignore")
+
+    @classmethod
+    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
+        return cls(llama_cpp.Llama(model_path=path, vocab_only=True))
+
+
+class LlamaHFTokenizer(BaseLlamaTokenizer):
+    def __init__(self, hf_tokenizer: Any):
+        self.hf_tokenizer = hf_tokenizer
+
+    def tokenize(
+        self, text: bytes, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        return self.hf_tokenizer.encode(
+            text.decode("utf-8", errors="ignore"), add_special_tokens=special
+        )
+
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
+        if prev_tokens is not None:
+            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+            prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
+                "utf-8", errors="ignore"
+            )
+            return text[len(prev_text) :]
+        else:
+            return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
+        try:
+            from transformers import AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library is required to use the `HFTokenizer`."
+                "You can install it with `pip install transformers`."
+            )
+        hf_tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path=pretrained_model_name_or_path
+        )
+        return cls(hf_tokenizer)
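For reference, a minimal usage sketch of the tokenizer API that this patch moves into `llama_cpp/llama_tokenizer.py`. It only exercises methods defined in the diff above; the Hugging Face repo id `gpt2` is an arbitrary illustrative choice, not something the patch prescribes.

```python
# Illustrative sketch, assuming `transformers` is installed
# (pip install transformers) and "gpt2" is used as an example repo id.
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

tok = LlamaHFTokenizer.from_pretrained("gpt2")

# tokenize() takes bytes and returns a List[int] of token ids;
# detokenize() returns bytes.
ids = tok.tokenize(b"Hello, world!")
print(ids)
print(tok.detokenize(ids))

# Passing prev_tokens returns only the newly added byte suffix,
# which is how incremental detokenization is expressed in this API.
suffix = tok.detokenize(ids, prev_tokens=ids[:-1])
print(suffix)
```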