fix: Update from_pretrained defaults to match hf_hub_download

This commit is contained in:
Andrei Betlen 2024-02-22 00:10:23 -05:00
parent dd22010e85
commit e6d6260a91

View file

@ -1885,8 +1885,9 @@ class Llama:
cls,
repo_id: str,
filename: Optional[str],
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
**kwargs: Any,
) -> "Llama":
"""Create a Llama model from a pretrained model name or path.
@ -1945,17 +1946,28 @@ class Llama:
subfolder = str(Path(matching_file).parent)
filename = Path(matching_file).name
local_dir = "."
# download the file
hf_hub_download(
repo_id=repo_id,
local_dir=local_dir,
filename=filename,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir,
)
if local_dir is None:
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir,
local_files_only=True,
)
else:
model_path = os.path.join(local_dir, filename)
return cls(