fix: Update from_pretrained defaults to match hf_hub_download
This commit is contained in:
parent
dd22010e85
commit
e6d6260a91
1 changed files with 17 additions and 5 deletions
|
@ -1885,8 +1885,9 @@ class Llama:
|
|||
cls,
|
||||
repo_id: str,
|
||||
filename: Optional[str],
|
||||
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
|
||||
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
||||
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
||||
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "Llama":
|
||||
"""Create a Llama model from a pretrained model name or path.
|
||||
|
@ -1945,18 +1946,29 @@ class Llama:
|
|||
subfolder = str(Path(matching_file).parent)
|
||||
filename = Path(matching_file).name
|
||||
|
||||
local_dir = "."
|
||||
|
||||
# download the file
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
local_dir=local_dir,
|
||||
filename=filename,
|
||||
subfolder=subfolder,
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
model_path = os.path.join(local_dir, filename)
|
||||
if local_dir is None:
|
||||
model_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
subfolder=subfolder,
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=True,
|
||||
|
||||
)
|
||||
else:
|
||||
model_path = os.path.join(local_dir, filename)
|
||||
|
||||
return cls(
|
||||
model_path=model_path,
|
||||
|
|
Loading…
Reference in a new issue