Refactor autotokenizer format to reusable function
parent b0e597e46e
commit bbffdaebaa

1 changed file with 22 additions and 20 deletions
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import os
+
 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
 
 from . import llama_types
 from . import llama
@@ -327,6 +329,26 @@ def get_chat_format(name: str):
     )
 
 
+def hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path: Union[str, os.PathLike[str]]) -> ChatFormatter:
+    # https://huggingface.co/docs/transformers/main/chat_templating
+    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
+    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+
+    def format_autotokenizer(
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        tokenizer.use_default_system_prompt = False
+        _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+        # Return formatted prompt and eos token by default
+        return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
+
+    return format_autotokenizer
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
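For reference, a minimal usage sketch of the new factory. The model id is an example, and the sketch assumes it runs in the same module as the function above with the transformers package installed; none of this is part of the commit itself.

# Build a formatter once, then call it per request (sketch, hypothetical model id)
formatter = hf_autotokenizer_to_chat_formatter(
    "mistralai/Mistral-7B-Instruct-v0.1"
)
response = formatter(
    messages=[{"role": "user", "content": "Hello, how are you?"}]
)
print(response.prompt)  # prompt rendered by the model's chat template
print(response.stop)    # the tokenizer's eos_token, e.g. "</s>"

Because the tokenizer is loaded once in the closure, repeated calls to the returned formatter avoid re-downloading or re-parsing the tokenizer config.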
@@ -510,26 +532,6 @@ def format_chatml(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt)
 
-# eg, export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.1
-@register_chat_format("autotokenizer")
-def format_autotokenizer(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    **kwargs: Any,
-) -> ChatFormatterResponse:
-    # https://huggingface.co/docs/transformers/main/chat_templating
-    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
-    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
-    import os
-    from transformers import AutoTokenizer
-    huggingFaceModel = os.getenv("HF_MODEL")  # eg, mistralai/Mistral-7B-Instruct-v0.1
-    print(huggingFaceModel)
-    if not huggingFaceModel:
-        raise Exception("HF_MODEL needs to be set in env to use chat format 'autotokenizer'")
-    tokenizer = AutoTokenizer.from_pretrained(huggingFaceModel)
-    tokenizer.use_default_system_prompt = False
-    _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
-    # Return formatted prompt and eos token by default
-    return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
 
 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
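The env-driven behaviour of the removed "autotokenizer" format can be recovered on top of the new factory. A sketch, not part of this commit, assuming callers still want HF_MODEL to select the model as the deleted code did:

# Rebuild the removed HF_MODEL-driven formatter via the new factory (sketch;
# the HF_MODEL handling mirrors the deleted code, minus the debug print)
import os

hf_model = os.getenv("HF_MODEL")  # e.g. mistralai/Mistral-7B-Instruct-v0.1
if not hf_model:
    raise ValueError(
        "HF_MODEL needs to be set in env to use chat format 'autotokenizer'"
    )
format_autotokenizer = hf_autotokenizer_to_chat_formatter(hf_model)

The factory itself performs no validation, so the HF_MODEL check stays with the caller; the model loaded is now an explicit argument rather than hidden global state.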