Add lora_path parameter to Llama model
parent 35abf89552
commit eb7f278cc6

1 changed file with 15 additions and 0 deletions
@@ -39,6 +39,7 @@ class Llama:
         n_threads: Optional[int] = None,
         n_batch: int = 8,
         last_n_tokens_size: int = 64,
+        lora_path: Optional[str] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -57,6 +58,7 @@ class Llama:
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
             n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            lora_path: Path to a LoRA file to apply to the model.
             verbose: Print verbose output to stderr.

         Raises:
@@ -108,6 +110,17 @@ class Llama:
             self.model_path.encode("utf-8"), self.params
         )

+        self.lora_path = None
+        if lora_path:
+            self.lora_path = lora_path
+            if llama_cpp.llama_apply_lora_from_file(
+                self.ctx,
+                self.lora_path.encode("utf-8"),
+                self.model_path.encode("utf-8"),
+                llama_cpp.c_int(self.n_threads),
+            ):
+                raise RuntimeError(f"Failed to apply LoRA from path: {self.lora_path}")
+
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

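For reference, a minimal usage sketch of the new parameter. The model and adapter paths below are hypothetical placeholders, not files referenced by this commit; the base model path passed to llama_apply_lora_from_file above serves as the source for weights the adapter does not override.

from llama_cpp import Llama

# Hypothetical paths; substitute a real GGML base model and LoRA adapter file.
# The adapter is applied once at load time, raising RuntimeError on failure.
llm = Llama(
    model_path="./models/7B/ggml-model.bin",
    lora_path="./models/7B/ggml-adapter-model.bin",
)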
@@ -802,6 +815,7 @@ class Llama:
             last_n_tokens_size=self.last_n_tokens_size,
             n_batch=self.n_batch,
             n_threads=self.n_threads,
+            lora_path=self.lora_path,
         )

     def __setstate__(self, state):
@@ -819,6 +833,7 @@ class Llama:
             n_threads=state["n_threads"],
             n_batch=state["n_batch"],
             last_n_tokens_size=state["last_n_tokens_size"],
+            lora_path=state["lora_path"],
             verbose=state["verbose"],
         )
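Since __getstate__ and __setstate__ now round-trip lora_path, a pickled Llama should re-apply the adapter when restored. A quick sketch under that assumption, continuing from the llm instance above:

import pickle

# Pickling captures the constructor arguments (now including lora_path);
# unpickling rebuilds the context and re-applies the LoRA via __setstate__.
data = pickle.dumps(llm)
restored = pickle.loads(data)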