fix: Update from_pretrained defaults to match hf_hub_download
This commit is contained in:
parent
dd22010e85
commit
e6d6260a91
1 changed files with 17 additions and 5 deletions
|
@ -1885,8 +1885,9 @@ class Llama:
|
||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
filename: Optional[str],
|
filename: Optional[str],
|
||||||
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
|
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
||||||
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
||||||
|
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "Llama":
|
) -> "Llama":
|
||||||
"""Create a Llama model from a pretrained model name or path.
|
"""Create a Llama model from a pretrained model name or path.
|
||||||
|
@ -1945,18 +1946,29 @@ class Llama:
|
||||||
subfolder = str(Path(matching_file).parent)
|
subfolder = str(Path(matching_file).parent)
|
||||||
filename = Path(matching_file).name
|
filename = Path(matching_file).name
|
||||||
|
|
||||||
local_dir = "."
|
|
||||||
|
|
||||||
# download the file
|
# download the file
|
||||||
hf_hub_download(
|
hf_hub_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
local_dir=local_dir,
|
|
||||||
filename=filename,
|
filename=filename,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
|
local_dir=local_dir,
|
||||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||||
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_path = os.path.join(local_dir, filename)
|
if local_dir is None:
|
||||||
|
model_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
subfolder=subfolder,
|
||||||
|
local_dir=local_dir,
|
||||||
|
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=True,
|
||||||
|
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_path = os.path.join(local_dir, filename)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
|
|
Loading…
Reference in a new issue