From e6d6260a91b7831733f7d1f73c7af46a3e8185ed Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Thu, 22 Feb 2024 00:10:23 -0500
Subject: [PATCH] fix: Update from_pretrained defaults to match hf_hub_download

---
 llama_cpp/llama.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 9fc4ec2..1226545 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1885,8 +1885,9 @@ class Llama:
         cls,
         repo_id: str,
         filename: Optional[str],
-        local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
+        local_dir: Optional[Union[str, os.PathLike[str]]] = None,
         local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+        cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
         **kwargs: Any,
     ) -> "Llama":
         """Create a Llama model from a pretrained model name or path.
@@ -1945,18 +1946,29 @@ class Llama:
             subfolder = str(Path(matching_file).parent)
             filename = Path(matching_file).name
 
-            local_dir = "."
-
         # download the file
         hf_hub_download(
             repo_id=repo_id,
-            local_dir=local_dir,
             filename=filename,
             subfolder=subfolder,
+            local_dir=local_dir,
             local_dir_use_symlinks=local_dir_use_symlinks,
+            cache_dir=cache_dir,
         )
 
-        model_path = os.path.join(local_dir, filename)
+        if local_dir is None:
+            model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                subfolder=subfolder,
+                local_dir=local_dir,
+                local_dir_use_symlinks=local_dir_use_symlinks,
+                cache_dir=cache_dir,
+                local_files_only=True,
+
+            )
+        else:
+            model_path = os.path.join(local_dir, filename)
 
         return cls(
             model_path=model_path,