Clean up stdout / stderr suppression

This commit is contained in:
Andrei Betlen 2023-11-03 13:02:15 -04:00
parent 4ea7027c41
commit 2ec043af76
2 changed files with 14 additions and 26 deletions

View file

@ -9,8 +9,14 @@ class suppress_stdout_stderr(object):
sys = sys
os = os
def __init__(self, disable: bool = True):
self.disable = disable
# Oddly enough this works better than the contextlib version
def __enter__(self):
if self.disable:
return self
self.outnull_file = self.open(self.os.devnull, "w")
self.errnull_file = self.open(self.os.devnull, "w")
@ -31,6 +37,9 @@ class suppress_stdout_stderr(object):
return self
def __exit__(self, *_):
if self.disable:
return
self.sys.stdout = self.old_stdout
self.sys.stderr = self.old_stderr

View file

@ -296,11 +296,8 @@ class Llama:
self.numa = numa
if not Llama.__backend_initialized:
if self.verbose:
with suppress_stdout_stderr(disable=self.verbose):
llama_cpp.llama_backend_init(self.numa)
else:
with suppress_stdout_stderr():
llama_cpp.llama_backend_init(self.numa)
Llama.__backend_initialized = True
self.model_path = model_path
@ -379,38 +376,23 @@ class Llama:
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
if verbose:
with suppress_stdout_stderr(disable=self.verbose):
self.model = llama_cpp.llama_load_model_from_file(
self.model_path.encode("utf-8"), self.model_params
)
else:
with suppress_stdout_stderr():
self.model = llama_cpp.llama_load_model_from_file(
self.model_path.encode("utf-8"), self.model_params
)
assert self.model is not None
if verbose:
with suppress_stdout_stderr(disable=self.verbose):
self.ctx = llama_cpp.llama_new_context_with_model(
self.model, self.context_params
)
else:
with suppress_stdout_stderr():
self.ctx = llama_cpp.llama_new_context_with_model(
self.model, self.context_params
)
assert self.ctx is not None
if verbose:
with suppress_stdout_stderr(disable=self.verbose):
self.batch = llama_cpp.llama_batch_init(
self.n_batch, 0, 1
)
else:
with suppress_stdout_stderr():
self.batch = llama_cpp.llama_batch_init(
self.n_batch, 0, 1
)
if self.lora_path:
if llama_cpp.llama_model_apply_lora_from_file(
@ -1615,11 +1597,8 @@ class Llama:
self.ctx = None
def __del__(self):
if self.verbose:
with suppress_stdout_stderr(disable=self.verbose):
self._free_model()
else:
with suppress_stdout_stderr():
self._free_model()
def __getstate__(self):
return dict(