feat: Move tokenizer to own module
This commit is contained in:
parent
2ef7ba3aed
commit
b5fca911b5
2 changed files with 100 additions and 65 deletions
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import abc
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
@ -15,7 +14,6 @@ from typing import (
|
||||||
Iterator,
|
Iterator,
|
||||||
Deque,
|
Deque,
|
||||||
Callable,
|
Callable,
|
||||||
Any,
|
|
||||||
)
|
)
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
@ -31,6 +29,10 @@ from .llama_cache import (
|
||||||
LlamaDiskCache, # type: ignore
|
LlamaDiskCache, # type: ignore
|
||||||
LlamaRAMCache, # type: ignore
|
LlamaRAMCache, # type: ignore
|
||||||
)
|
)
|
||||||
|
from .llama_tokenizer import (
|
||||||
|
BaseLlamaTokenizer,
|
||||||
|
LlamaTokenizer
|
||||||
|
)
|
||||||
import llama_cpp.llama_cpp as llama_cpp
|
import llama_cpp.llama_cpp as llama_cpp
|
||||||
import llama_cpp.llama_chat_format as llama_chat_format
|
import llama_cpp.llama_chat_format as llama_chat_format
|
||||||
|
|
||||||
|
@ -1747,69 +1749,6 @@ class Llama:
|
||||||
return longest_prefix
|
return longest_prefix
|
||||||
|
|
||||||
|
|
||||||
class BaseLlamaTokenizer(abc.ABC):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaTokenizer(BaseLlamaTokenizer):
|
|
||||||
def __init__(self, llama: Llama):
|
|
||||||
self.llama = llama
|
|
||||||
self._model = llama._model # type: ignore
|
|
||||||
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
|
|
||||||
return self._model.tokenize(text, add_bos=add_bos, special=special)
|
|
||||||
|
|
||||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
|
||||||
if prev_tokens is not None:
|
|
||||||
return self._model.detokenize(tokens[len(prev_tokens):])
|
|
||||||
else:
|
|
||||||
return self._model.detokenize(tokens)
|
|
||||||
|
|
||||||
def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
|
|
||||||
return self.tokenize(
|
|
||||||
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode(self, tokens: List[int]) -> str:
|
|
||||||
return self.detokenize(tokens).decode("utf-8", errors="ignore")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
|
||||||
return cls(Llama(model_path=path, vocab_only=True))
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaHFTokenizer(BaseLlamaTokenizer):
|
|
||||||
def __init__(self, hf_tokenizer: Any):
|
|
||||||
self.hf_tokenizer = hf_tokenizer
|
|
||||||
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
|
|
||||||
return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)
|
|
||||||
|
|
||||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
|
||||||
if prev_tokens is not None:
|
|
||||||
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
|
||||||
prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
|
|
||||||
return text[len(prev_text):]
|
|
||||||
else:
|
|
||||||
return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
|
|
||||||
try:
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The `transformers` library is required to use the `HFTokenizer`."
|
|
||||||
"You can install it with `pip install transformers`."
|
|
||||||
)
|
|
||||||
hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
|
|
||||||
return cls(hf_tokenizer)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaState:
|
class LlamaState:
|
||||||
|
|
96
llama_cpp/llama_tokenizer.py
Normal file
96
llama_cpp/llama_tokenizer.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import (
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Any,
|
||||||
|
)
|
||||||
|
|
||||||
|
import llama_cpp
|
||||||
|
from llama_cpp.llama_types import List
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLlamaTokenizer(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def tokenize(
|
||||||
|
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||||
|
) -> List[int]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def detokenize(
|
||||||
|
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||||
|
) -> bytes:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaTokenizer(BaseLlamaTokenizer):
|
||||||
|
def __init__(self, llama: llama_cpp.Llama):
|
||||||
|
self.llama = llama
|
||||||
|
self._model = llama._model # type: ignore
|
||||||
|
|
||||||
|
def tokenize(
|
||||||
|
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||||
|
) -> List[int]:
|
||||||
|
return self._model.tokenize(text, add_bos=add_bos, special=special)
|
||||||
|
|
||||||
|
def detokenize(
|
||||||
|
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||||
|
) -> bytes:
|
||||||
|
if prev_tokens is not None:
|
||||||
|
return self._model.detokenize(tokens[len(prev_tokens) :])
|
||||||
|
else:
|
||||||
|
return self._model.detokenize(tokens)
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, text: str, add_bos: bool = True, special: bool = True
|
||||||
|
) -> List[int]:
|
||||||
|
return self.tokenize(
|
||||||
|
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, tokens: List[int]) -> str:
|
||||||
|
return self.detokenize(tokens).decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
||||||
|
return cls(llama_cpp.Llama(model_path=path, vocab_only=True))
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaHFTokenizer(BaseLlamaTokenizer):
|
||||||
|
def __init__(self, hf_tokenizer: Any):
|
||||||
|
self.hf_tokenizer = hf_tokenizer
|
||||||
|
|
||||||
|
def tokenize(
|
||||||
|
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||||
|
) -> List[int]:
|
||||||
|
return self.hf_tokenizer.encode(
|
||||||
|
text.decode("utf-8", errors="ignore"), add_special_tokens=special
|
||||||
|
)
|
||||||
|
|
||||||
|
def detokenize(
|
||||||
|
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||||
|
) -> bytes:
|
||||||
|
if prev_tokens is not None:
|
||||||
|
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
||||||
|
prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
return text[len(prev_text) :]
|
||||||
|
else:
|
||||||
|
return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
|
||||||
|
try:
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The `transformers` library is required to use the `HFTokenizer`."
|
||||||
|
"You can install it with `pip install transformers`."
|
||||||
|
)
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
pretrained_model_name_or_path=pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
return cls(hf_tokenizer)
|
Loading…
Add table
Reference in a new issue