feat: Add .close() method to Llama class to explicitly free model from memory (#1513)
* feat: add explicit methods to free model

This commit introduces a `close` method to both `Llama` and `_LlamaModel`, allowing users to explicitly free the model from RAM/VRAM.

The previous implementation relied on the destructor of `_LlamaModel` to free the model. However, in Python the timing of destructor calls is unclear; for instance, the `del` statement does not guarantee immediate invocation of the destructor. This commit provides an explicit method to release the model, which works immediately and allows the user to load another model without memory issues.

Additionally, this commit implements a context manager in the `Llama` class, enabling the automatic closure of the `Llama` object when used with the `with` statement.

* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch

This commit enables automatic resource management by implementing the `ContextManager` protocol in `_LlamaModel`, `_LlamaContext`, and `_LlamaBatch`. This ensures that resources are properly managed and released within a `with` statement, enhancing robustness and safety in resource handling.

* feat: add ExitStack for Llama's internal class closure

This update implements ExitStack to manage and close internal classes in Llama, enabling efficient and safe resource management.

* Use contextlib ExitStack and closing

* Explicitly free model when closing resources on server

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
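As a usage illustration (not part of the diff; the model path and prompt below are placeholders), the new `close()` method and the stdlib `contextlib.closing` helper can be combined like this:

from contextlib import closing
from llama_cpp import Llama

# Explicit release: free the model from RAM/VRAM as soon as it is no longer
# needed, instead of waiting for the garbage collector to run a destructor.
llm = Llama(model_path="./models/example.gguf")  # placeholder path
llm.close()

# Scoped release: closing() calls .close() when the block exits, which makes
# it easy to load several models one after another without memory buildup.
with closing(Llama(model_path="./models/example.gguf")) as llm:
    output = llm("Q: Name the planets in the solar system. A:", max_tokens=32)

The commit message also describes using `Llama` directly in a `with` statement for the same effect.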
parent dbcf64cf07
commit 320a5d7ea5
3 changed files with 50 additions and 30 deletions
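The claim above, that the `del` statement does not guarantee immediate invocation of the destructor, can be reproduced with a small reference-cycle example (illustrative only, unrelated to llama.cpp):

import gc

class Holder:
    def __del__(self):
        print("freed")

a = Holder()
b = Holder()
a.other, b.other = b, a  # a reference cycle keeps both objects alive
del a, b                 # the names are gone, but __del__ does not run here
gc.collect()             # only the cycle collector finally prints "freed" twice

Until the collector happens to run, any RAM/VRAM held by such an object stays allocated, which is the motivation for an explicit `close()`.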
llama_cpp/_internals.py

@@ -9,6 +9,7 @@ from typing import (
     Sequence,
 )
 from dataclasses import dataclass, field
+from contextlib import ExitStack
 
 import numpy as np
 import numpy.typing as npt
@@ -27,9 +28,6 @@ class _LlamaModel:
     """Intermediate Python wrapper for a llama.cpp llama_model.
     NOTE: For stability it's recommended you use the Llama class instead."""
 
-    _llama_free_model = None
-    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
-
     def __init__(
         self,
         *,
@@ -40,8 +38,7 @@ class _LlamaModel:
         self.path_model = path_model
         self.params = params
         self.verbose = verbose
+        self._exit_stack = ExitStack()
 
-        self._llama_free_model = llama_cpp._lib.llama_free_model  # type: ignore
-
         self.model = None
 
@@ -56,11 +53,17 @@ class _LlamaModel:
         if self.model is None:
             raise ValueError(f"Failed to load model from file: {path_model}")
 
-    def __del__(self):
-        if self.model is not None and self._llama_free_model is not None:
-            self._llama_free_model(self.model)
+        def free_model():
+            if self.model is None:
+                return
+            llama_cpp.llama_free_model(self.model)
             self.model = None
 
+        self._exit_stack.callback(free_model)
+
+    def close(self):
+        self._exit_stack.close()
+
     def vocab_type(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_vocab_type(self.model)
@@ -257,8 +260,6 @@ class _LlamaContext:
     """Intermediate Python wrapper for a llama.cpp llama_context.
     NOTE: For stability it's recommended you use the Llama class instead."""
 
-    _llama_free = None
-
     def __init__(
         self,
         *,
@@ -269,24 +270,28 @@ class _LlamaContext:
         self.model = model
         self.params = params
         self.verbose = verbose
+        self._exit_stack = ExitStack()
 
-        self._llama_free = llama_cpp._lib.llama_free  # type: ignore
         self.ctx = None
 
         assert self.model.model is not None
 
-        self.ctx = llama_cpp.llama_new_context_with_model(
-            self.model.model, self.params
-        )
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)
 
         if self.ctx is None:
             raise ValueError("Failed to create llama_context")
 
-    def __del__(self):
-        if self.ctx is not None and self._llama_free is not None:
-            self._llama_free(self.ctx)
+        def free_ctx():
+            if self.ctx is None:
+                return
+            llama_cpp.llama_free(self.ctx)
             self.ctx = None
 
+        self._exit_stack.callback(free_ctx)
+
+    def close(self):
+        self._exit_stack.close()
+
     def n_ctx(self) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_n_ctx(self.ctx)
@@ -501,8 +506,6 @@ class _LlamaContext:
 
 
 class _LlamaBatch:
-    _llama_batch_free = None
-
     def __init__(
         self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
     ):
@@ -510,19 +513,24 @@ class _LlamaBatch:
         self.embd = embd
         self.n_seq_max = n_seq_max
         self.verbose = verbose
+        self._exit_stack = ExitStack()
 
-        self._llama_batch_free = llama_cpp._lib.llama_batch_free  # type: ignore
-
         self.batch = None
         self.batch = llama_cpp.llama_batch_init(
             self._n_tokens, self.embd, self.n_seq_max
         )
 
-    def __del__(self):
-        if self.batch is not None and self._llama_batch_free is not None:
-            self._llama_batch_free(self.batch)
+        def free_batch():
+            if self.batch is None:
+                return
+            llama_cpp.llama_batch_free(self.batch)
             self.batch = None
 
+        self._exit_stack.callback(free_batch)
+
+    def close(self):
+        self._exit_stack.close()
+
     def n_tokens(self) -> int:
         assert self.batch is not None
         return self.batch.n_tokens
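All three internal classes above now follow the same pattern: a nested cleanup function registered with `ExitStack.callback`, and a `close()` that drains the stack. A standalone sketch of that pattern, with a placeholder `Resource` class standing in for the wrapped llama.cpp handles:

from contextlib import ExitStack

class Resource:
    def __init__(self):
        self._exit_stack = ExitStack()
        self.handle = object()  # stands in for a llama.cpp pointer

        def free():
            # Idempotent: the handle is freed at most once.
            if self.handle is None:
                return
            self.handle = None  # stands in for the C-side free call

        self._exit_stack.callback(free)

    def close(self):
        # Runs every registered callback (LIFO), then empties the stack.
        self._exit_stack.close()

r = Resource()
r.close()
r.close()  # safe: the stack is already empty, nothing runs twice

Because each callback is popped when the stack is closed, calling `close()` repeatedly releases the underlying handle at most once.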
llama_cpp/llama.py

@@ -9,7 +9,9 @@ import ctypes
 import typing
 import fnmatch
 import warnings
+import contextlib
 import multiprocessing
+from types import TracebackType
 
 from typing import (
     List,
@@ -21,6 +23,7 @@ from typing import (
     Deque,
     Callable,
     Dict,
+    Type,
 )
 from collections import deque
 from pathlib import Path
@@ -350,9 +353,11 @@ class Llama:
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
 
-        self._model = _LlamaModel(
+        self._stack = contextlib.ExitStack()
+
+        self._model = self._stack.enter_context(contextlib.closing(_LlamaModel(
             path_model=self.model_path, params=self.model_params, verbose=self.verbose
-        )
+        )))
 
         # Override tokenizer
         self.tokenizer_ = tokenizer or LlamaTokenizer(self)
@@ -364,18 +369,18 @@ class Llama:
             self.context_params.n_ctx = self._model.n_ctx_train()
             self.context_params.n_batch = self.n_batch
 
-        self._ctx = _LlamaContext(
+        self._ctx = self._stack.enter_context(contextlib.closing(_LlamaContext(
             model=self._model,
             params=self.context_params,
             verbose=self.verbose,
-        )
+        )))
 
-        self._batch = _LlamaBatch(
+        self._batch = self._stack.enter_context(contextlib.closing(_LlamaBatch(
             n_tokens=self.n_batch,
             embd=0,
             n_seq_max=self.context_params.n_ctx,
             verbose=self.verbose,
-        )
+        )))
 
         if self.lora_path:
             if self._model.apply_lora_from_file(
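In `Llama.__init__`, each helper is wrapped in `contextlib.closing` and entered on a single `ExitStack`, so one `self._stack.close()` releases the batch, context, and model in reverse creation order. A reduced, self-contained sketch of that composition (the `Res` and `Owner` classes are placeholders, not code from the diff):

import contextlib

class Res:
    def __init__(self, name: str):
        self.name = name

    def close(self):
        print(f"closed {self.name}")

class Owner:
    def __init__(self):
        self._stack = contextlib.ExitStack()
        # closing() turns any object with a close() method into a context
        # manager; enter_context() defers its exit to the owner's stack.
        self.model = self._stack.enter_context(contextlib.closing(Res("model")))
        self.ctx = self._stack.enter_context(contextlib.closing(Res("ctx")))
        self.batch = self._stack.enter_context(contextlib.closing(Res("batch")))

    def close(self):
        self._stack.close()  # closes batch, then ctx, then model (LIFO)

Owner().close()  # prints: closed batch, closed ctx, closed model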
@@ -1959,6 +1964,10 @@ class Llama:
         """Return the pooling type."""
         return self._ctx.pooling_type()
 
+    def close(self) -> None:
+        """Explicitly free the model from memory."""
+        self._stack.close()
+
     @staticmethod
     def logits_to_logprobs(
         logits: Union[npt.NDArray[np.single], List], axis: int = -1
llama_cpp/server/model.py

@@ -44,6 +44,8 @@ class LlamaProxy:
             if self._current_model is not None:
                 return self._current_model
 
+        if self._current_model:
+            self._current_model.close()
         self._current_model = None
 
         settings = self._model_settings_dict[model]
@@ -65,6 +67,7 @@ class LlamaProxy:
 
     def free(self):
         if self._current_model:
+            self._current_model.close()
             del self._current_model
 
     @staticmethod
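On the server side, `LlamaProxy` now calls `close()` on the current model before dropping the reference, so memory is released before the next model is loaded rather than whenever garbage collection happens to run. A hedged sketch of that swap pattern (the `ModelSlot` class and `load` factory are hypothetical, not part of the server code):

class ModelSlot:
    """Holds at most one loaded model at a time (e.g. a Llama instance)."""

    def __init__(self):
        self._current = None

    def swap(self, load):
        # Free the old model immediately instead of waiting for garbage
        # collection, then load the replacement. `load` is any zero-argument
        # factory returning an object that exposes close().
        if self._current is not None:
            self._current.close()
            self._current = None
        self._current = load()
        return self._current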