feat: Add .close() method to Llama class to explicitly free model from memory (#1513)

* feat: add explicit methods to free model

This commit introduces a `close` method to both `Llama` and `_LlamaModel`,
allowing users to explicitly free the model from RAM/VRAM.

The previous implementation relied on the destructor of `_LlamaModel` to free
the model. However, in Python the timing of destructor calls is not guaranteed:
the `del` statement only removes a name binding and does not ensure that
`__del__` runs immediately.

This commit provides an explicit method to release the model; it takes effect
immediately and allows the user to load another model without running out of
RAM/VRAM.

Additionally, this commit implements a context manager in the `Llama` class,
enabling the automatic closure of the `Llama` object when used with the `with`
statement.
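
For example (an illustrative usage sketch; the model path and prompt are
placeholders, and `contextlib.closing` is shown as one way to get
with-statement cleanup from the new `close()` method):

    import contextlib
    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # placeholder path
    try:
        llm("Q: Name the planets in the solar system. A:", max_tokens=32)
    finally:
        llm.close()  # frees the model from RAM/VRAM right away

    # The same cleanup expressed as a with-statement:
    with contextlib.closing(Llama(model_path="./model.gguf")) as llm:
        llm("Q: Name the planets in the solar system. A:", max_tokens=32)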

* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch

This commit enables automatic resource management by implementing the
`ContextManager` protocol in `_LlamaModel`, `_LlamaContext`, and `_LlamaBatch`,
so each wrapper releases its underlying llama.cpp resource deterministically
when used in a `with` statement.
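
The pattern is roughly the following (a minimal sketch of the protocol, not the
exact wrapper code; `_Closable` is a stand-in name):

    class _Closable:
        def close(self) -> None:
            """Free the underlying llama.cpp handle (elided here)."""

        def __enter__(self) -> "_Closable":
            return self

        def __exit__(self, exc_type, exc_value, traceback) -> None:
            # Release the native resource even if the `with` body raised.
            self.close()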

* feat: add ExitStack for Llama's internal class closure

This update uses `contextlib.ExitStack` to track `Llama`'s internal wrapper
objects and close them safely, in reverse order of creation, from a single
place.
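
Each wrapper registers its cleanup as an `ExitStack` callback and exposes
`close()`. Roughly (a simplified sketch of the `_LlamaModel` change in the diff
below; error handling and loading details are elided):

    import llama_cpp
    from contextlib import ExitStack

    class _LlamaModelSketch:
        def __init__(self, *, path_model: str, params, verbose: bool = True):
            self._exit_stack = ExitStack()
            self.model = llama_cpp.llama_load_model_from_file(
                path_model.encode("utf-8"), params
            )

            def free_model():
                if self.model is None:
                    return
                llama_cpp.llama_free_model(self.model)
                self.model = None

            # free_model() runs when close() unwinds the stack.
            self._exit_stack.callback(free_model)

        def close(self):
            self._exit_stack.close()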

* Use contextlib ExitStack and closing
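
`contextlib.closing` adapts an object that only defines `close()` into a
context manager, and `ExitStack.enter_context` registers it for teardown. A
self-contained illustration of the ordering, with a toy `Resource` class
standing in for the internal wrappers:

    import contextlib

    class Resource:
        def __init__(self, name: str) -> None:
            self.name = name

        def close(self) -> None:
            print(f"closed {self.name}")

    stack = contextlib.ExitStack()
    model = stack.enter_context(contextlib.closing(Resource("model")))
    ctx = stack.enter_context(contextlib.closing(Resource("context")))
    batch = stack.enter_context(contextlib.closing(Resource("batch")))

    # Unwinds in reverse order: closed batch, closed context, closed model.
    stack.close()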

* Explicitly free model when closing resources on server
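
On the server, the model proxy now calls `close()` before dropping its
reference, so RAM/VRAM is reclaimed before the next model is loaded. A
hypothetical stand-in for that logic (`ModelProxySketch` and `load_model` are
illustrative names, not the server API):

    class ModelProxySketch:
        def __init__(self) -> None:
            self._current_model = None

        def swap(self, load_model):
            if self._current_model is not None:
                self._current_model.close()  # free memory before reloading
            self._current_model = load_model()
            return self._current_model

        def free(self) -> None:
            if self._current_model is not None:
                self._current_model.close()
                self._current_model = None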

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
Author: Junpei Kawamoto
Date: 2024-06-13 02:16:14 -06:00 (committed by GitHub)
Commit: 320a5d7ea5
Parent: dbcf64cf07
3 changed files with 50 additions and 30 deletions

llama_cpp/_internals.py

@@ -9,6 +9,7 @@ from typing import (
    Sequence,
)
from dataclasses import dataclass, field
from contextlib import ExitStack
import numpy as np
import numpy.typing as npt
@@ -27,9 +28,6 @@ class _LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.
    NOTE: For stability it's recommended you use the Llama class instead."""
    _llama_free_model = None
    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
    def __init__(
        self,
        *,
@@ -40,8 +38,7 @@ class _LlamaModel:
        self.path_model = path_model
        self.params = params
        self.verbose = verbose
        self._llama_free_model = llama_cpp._lib.llama_free_model  # type: ignore
        self._exit_stack = ExitStack()
        self.model = None
@@ -56,11 +53,17 @@
        if self.model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")
    def __del__(self):
        if self.model is not None and self._llama_free_model is not None:
            self._llama_free_model(self.model)
        def free_model():
            if self.model is None:
                return
            llama_cpp.llama_free_model(self.model)
            self.model = None
        self._exit_stack.callback(free_model)
    def close(self):
        self._exit_stack.close()
    def vocab_type(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_vocab_type(self.model)
@@ -257,8 +260,6 @@
    """Intermediate Python wrapper for a llama.cpp llama_context.
    NOTE: For stability it's recommended you use the Llama class instead."""
    _llama_free = None
    def __init__(
        self,
        *,
@@ -269,24 +270,28 @@
        self.model = model
        self.params = params
        self.verbose = verbose
        self._exit_stack = ExitStack()
        self._llama_free = llama_cpp._lib.llama_free  # type: ignore
        self.ctx = None
        assert self.model.model is not None
        self.ctx = llama_cpp.llama_new_context_with_model(
            self.model.model, self.params
        )
        self.ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)
        if self.ctx is None:
            raise ValueError("Failed to create llama_context")
    def __del__(self):
        if self.ctx is not None and self._llama_free is not None:
            self._llama_free(self.ctx)
        def free_ctx():
            if self.ctx is None:
                return
            llama_cpp.llama_free(self.ctx)
            self.ctx = None
        self._exit_stack.callback(free_ctx)
    def close(self):
        self._exit_stack.close()
    def n_ctx(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_n_ctx(self.ctx)
@@ -501,8 +506,6 @@
class _LlamaBatch:
    _llama_batch_free = None
    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
@@ -510,19 +513,24 @@
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose
        self._llama_batch_free = llama_cpp._lib.llama_batch_free  # type: ignore
        self._exit_stack = ExitStack()
        self.batch = None
        self.batch = llama_cpp.llama_batch_init(
            self._n_tokens, self.embd, self.n_seq_max
        )
    def __del__(self):
        if self.batch is not None and self._llama_batch_free is not None:
            self._llama_batch_free(self.batch)
        def free_batch():
            if self.batch is None:
                return
            llama_cpp.llama_batch_free(self.batch)
            self.batch = None
        self._exit_stack.callback(free_batch)
    def close(self):
        self._exit_stack.close()
    def n_tokens(self) -> int:
        assert self.batch is not None
        return self.batch.n_tokens

llama_cpp/llama.py

@@ -9,7 +9,9 @@ import ctypes
import typing
import fnmatch
import warnings
import contextlib
import multiprocessing
from types import TracebackType
from typing import (
    List,
@@ -21,6 +23,7 @@ from typing import (
    Deque,
    Callable,
    Dict,
    Type,
)
from collections import deque
from pathlib import Path
@@ -350,9 +353,11 @@
        if not os.path.exists(model_path):
            raise ValueError(f"Model path does not exist: {model_path}")
        self._model = _LlamaModel(
        self._stack = contextlib.ExitStack()
        self._model = self._stack.enter_context(contextlib.closing(_LlamaModel(
            path_model=self.model_path, params=self.model_params, verbose=self.verbose
        )
        )))
        # Override tokenizer
        self.tokenizer_ = tokenizer or LlamaTokenizer(self)
@@ -364,18 +369,18 @@
            self.context_params.n_ctx = self._model.n_ctx_train()
            self.context_params.n_batch = self.n_batch
        self._ctx = _LlamaContext(
        self._ctx = self._stack.enter_context(contextlib.closing(_LlamaContext(
            model=self._model,
            params=self.context_params,
            verbose=self.verbose,
        )
        )))
        self._batch = _LlamaBatch(
        self._batch = self._stack.enter_context(contextlib.closing(_LlamaBatch(
            n_tokens=self.n_batch,
            embd=0,
            n_seq_max=self.context_params.n_ctx,
            verbose=self.verbose,
        )
        )))
        if self.lora_path:
            if self._model.apply_lora_from_file(
@@ -1959,6 +1964,10 @@
        """Return the pooling type."""
        return self._ctx.pooling_type()
    def close(self) -> None:
        """Explicitly free the model from memory."""
        self._stack.close()
    @staticmethod
    def logits_to_logprobs(
        logits: Union[npt.NDArray[np.single], List], axis: int = -1

llama_cpp/server/model.py

@@ -44,6 +44,8 @@ class LlamaProxy:
            if self._current_model is not None:
                return self._current_model
        if self._current_model:
            self._current_model.close()
        self._current_model = None
        settings = self._model_settings_dict[model]
@@ -65,6 +67,7 @@
    def free(self):
        if self._current_model:
            self._current_model.close()
            del self._current_model
    @staticmethod