Refactor autotokenizer format to reusable function

Andrei Betlen 2023-11-06 09:07:27 -05:00
parent b0e597e46e
commit bbffdaebaa


@@ -1,7 +1,9 @@
 from __future__ import annotations
+import os
 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
 from . import llama_types
 from . import llama
@@ -327,6 +329,26 @@ def get_chat_format(name: str):
     )
 
+
+def hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path: Union[str, os.PathLike[str]]) -> ChatFormatter:
+    # https://huggingface.co/docs/transformers/main/chat_templating
+    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
+    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+
+    def format_autotokenizer(
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        tokenizer.use_default_system_prompt = False
+        _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+        # Return formatted prompt and eos token by default
+        return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
+
+    return format_autotokenizer
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
@@ -510,26 +532,6 @@ def format_chatml(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt)
 
-
-# eg, export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.1
-@register_chat_format("autotokenizer")
-def format_autotokenizer(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    **kwargs: Any,
-) -> ChatFormatterResponse:
-    # https://huggingface.co/docs/transformers/main/chat_templating
-    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
-    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
-    import os
-    from transformers import AutoTokenizer
-    huggingFaceModel = os.getenv("HF_MODEL")  # eg, mistralai/Mistral-7B-Instruct-v0.1
-    print(huggingFaceModel)
-    if not huggingFaceModel:
-        raise Exception("HF_MODEL needs to be set in env to use chat format 'autotokenizer'")
-    tokenizer = AutoTokenizer.from_pretrained(huggingFaceModel)
-    tokenizer.use_default_system_prompt = False
-    _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
-    # Return formatted prompt and eos token by default
-    return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
 
 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
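The removed "autotokenizer" registration could be rebuilt on top of the new helper; a sketch only, assuming the HF_MODEL environment-variable convention from the deleted code is kept (this commit itself does not re-register the format):

```python
import os

# Hypothetical re-registration of the old env-var-driven format using the new helper;
# the wiring below is an assumption, not part of this commit.
@register_chat_format("autotokenizer")
def format_autotokenizer(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    hf_model = os.getenv("HF_MODEL")  # e.g. mistralai/Mistral-7B-Instruct-v0.1
    if not hf_model:
        raise ValueError("HF_MODEL must be set to use chat format 'autotokenizer'")
    # Delegate to the reusable formatter built from the tokenizer's chat template.
    return hf_autotokenizer_to_chat_formatter(hf_model)(messages, **kwargs)
```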