diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index 91d8fd6..1b61eec 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -9,8 +9,14 @@ class suppress_stdout_stderr(object): sys = sys os = os + def __init__(self, disable: bool = True): + self.disable = disable + # Oddly enough this works better than the contextlib version def __enter__(self): + if self.disable: + return self + self.outnull_file = self.open(self.os.devnull, "w") self.errnull_file = self.open(self.os.devnull, "w") @@ -31,6 +37,9 @@ class suppress_stdout_stderr(object): return self def __exit__(self, *_): + if self.disable: + return + self.sys.stdout = self.old_stdout self.sys.stderr = self.old_stderr diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9af0588..d3b85c9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -296,11 +296,8 @@ class Llama: self.numa = numa if not Llama.__backend_initialized: - if self.verbose: + with suppress_stdout_stderr(disable=self.verbose): llama_cpp.llama_backend_init(self.numa) - else: - with suppress_stdout_stderr(): - llama_cpp.llama_backend_init(self.numa) Llama.__backend_initialized = True self.model_path = model_path @@ -379,38 +376,23 @@ class Llama: if not os.path.exists(model_path): raise ValueError(f"Model path does not exist: {model_path}") - if verbose: + with suppress_stdout_stderr(disable=self.verbose): self.model = llama_cpp.llama_load_model_from_file( self.model_path.encode("utf-8"), self.model_params ) - else: - with suppress_stdout_stderr(): - self.model = llama_cpp.llama_load_model_from_file( - self.model_path.encode("utf-8"), self.model_params - ) assert self.model is not None - if verbose: + with suppress_stdout_stderr(disable=self.verbose): self.ctx = llama_cpp.llama_new_context_with_model( self.model, self.context_params ) - else: - with suppress_stdout_stderr(): - self.ctx = llama_cpp.llama_new_context_with_model( - self.model, self.context_params - ) assert self.ctx is not None - if verbose: + with suppress_stdout_stderr(disable=self.verbose): self.batch = llama_cpp.llama_batch_init( self.n_batch, 0, 1 ) - else: - with suppress_stdout_stderr(): - self.batch = llama_cpp.llama_batch_init( - self.n_batch, 0, 1 - ) if self.lora_path: if llama_cpp.llama_model_apply_lora_from_file( @@ -1615,11 +1597,8 @@ class Llama: self.ctx = None def __del__(self): - if self.verbose: + with suppress_stdout_stderr(disable=self.verbose): self._free_model() - else: - with suppress_stdout_stderr(): - self._free_model() def __getstate__(self): return dict(