Merge branch 'main' of https://github.com/abetlen/llama-cpp-python
This commit is contained in:
commit
be2c961bc9
8 changed files with 203 additions and 113 deletions
|
@ -2,7 +2,11 @@ cmake_minimum_required(VERSION 3.4...3.22)
|
|||
|
||||
project(llama_cpp)
|
||||
|
||||
if (UNIX)
|
||||
option(FORCE_CMAKE "Force CMake build of Python bindings" OFF)
|
||||
|
||||
set(FORCE_CMAKE $ENV{FORCE_CMAKE})
|
||||
|
||||
if (UNIX AND NOT FORCE_CMAKE)
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/libllama.so
|
||||
COMMAND make libllama.so
|
||||
|
|
|
@ -105,12 +105,16 @@ python3 setup.py develop
|
|||
- __call__
|
||||
- create_chat_completion
|
||||
- set_cache
|
||||
- save_state
|
||||
- load_state
|
||||
- token_bos
|
||||
- token_eos
|
||||
show_root_heading: true
|
||||
|
||||
::: llama_cpp.LlamaCache
|
||||
|
||||
::: llama_cpp.LlamaState
|
||||
|
||||
::: llama_cpp.llama_cpp
|
||||
options:
|
||||
show_if_no_docstring: true
|
||||
|
|
|
@ -4,7 +4,7 @@ import uuid
|
|||
import time
|
||||
import math
|
||||
import multiprocessing
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
|
||||
from collections import deque
|
||||
|
||||
from . import llama_cpp
|
||||
|
@ -12,12 +12,53 @@ from .llama_types import *
|
|||
|
||||
|
||||
class LlamaCache:
|
||||
"""Cache for a llama.cpp model.
|
||||
"""Cache for a llama.cpp model."""
|
||||
|
||||
NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
|
||||
completion. It does not actually cache the results."""
|
||||
def __init__(self):
|
||||
self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
|
||||
|
||||
pass
|
||||
def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
|
||||
return [
|
||||
key
|
||||
for _, key in sorted(
|
||||
((len(key), key) for key in self.cache_state.keys()), reverse=True
|
||||
)
|
||||
]
|
||||
|
||||
def _find_key(
|
||||
self, key: Tuple[llama_cpp.llama_token, ...]
|
||||
) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
|
||||
for k in self._sorted_keys():
|
||||
if key[: len(k)] == k:
|
||||
return k
|
||||
return None
|
||||
|
||||
def __getitem__(
|
||||
self, key: Sequence[llama_cpp.llama_token]
|
||||
) -> Optional["LlamaState"]:
|
||||
_key = self._find_key(tuple(key))
|
||||
if _key is None:
|
||||
return None
|
||||
return self.cache_state[_key]
|
||||
|
||||
def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
|
||||
return self._find_key(tuple(key)) is not None
|
||||
|
||||
def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
|
||||
self.cache_state = dict() # NOTE: Currently limit to one cache entry.
|
||||
self.cache_state[tuple(key)] = value
|
||||
|
||||
|
||||
class LlamaState:
|
||||
def __init__(
|
||||
self,
|
||||
eval_tokens: Deque[llama_cpp.llama_token],
|
||||
eval_logits: Deque[List[float]],
|
||||
llama_state,
|
||||
):
|
||||
self.eval_tokens = eval_tokens
|
||||
self.eval_logits = eval_logits
|
||||
self.llama_state = llama_state
|
||||
|
||||
|
||||
class Llama:
|
||||
|
@ -37,8 +78,10 @@ class Llama:
|
|||
use_mlock: bool = False,
|
||||
embedding: bool = False,
|
||||
n_threads: Optional[int] = None,
|
||||
n_batch: int = 8,
|
||||
n_batch: int = 512,
|
||||
last_n_tokens_size: int = 64,
|
||||
lora_base: Optional[str] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""Load a llama.cpp model from `model_path`.
|
||||
|
@ -57,6 +100,8 @@ class Llama:
|
|||
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
|
||||
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
|
||||
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
|
||||
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
|
||||
lora_path: Path to a LoRA file to apply to the model.
|
||||
verbose: Print verbose output to stderr.
|
||||
|
||||
Raises:
|
||||
|
@ -75,32 +120,22 @@ class Llama:
|
|||
self.params.f16_kv = f16_kv
|
||||
self.params.logits_all = logits_all
|
||||
self.params.vocab_only = vocab_only
|
||||
self.params.use_mmap = use_mmap
|
||||
self.params.use_mmap = use_mmap if lora_path is None else False
|
||||
self.params.use_mlock = use_mlock
|
||||
self.params.embedding = embedding
|
||||
|
||||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.last_n_tokens_data = deque(
|
||||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size,
|
||||
maxlen=self.last_n_tokens_size,
|
||||
)
|
||||
self.tokens_consumed = 0
|
||||
self.tokens: List[llama_cpp.llama_token] = []
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.n_tokens = 0
|
||||
self.n_past = 0
|
||||
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
|
||||
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
|
||||
|
||||
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
|
||||
### saving and restoring state, this allows us to continue a completion if the last
|
||||
### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
|
||||
### because it does not take into account stop tokens which have been processed by the model.
|
||||
self._completion_bytes: List[bytes] = []
|
||||
self._cache: Optional[LlamaCache] = None
|
||||
###
|
||||
self.cache: Optional[LlamaCache] = None
|
||||
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
|
||||
self.lora_base = lora_base
|
||||
self.lora_path = lora_path
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise ValueError(f"Model path does not exist: {model_path}")
|
||||
|
||||
|
@ -108,6 +143,21 @@ class Llama:
|
|||
self.model_path.encode("utf-8"), self.params
|
||||
)
|
||||
|
||||
assert self.ctx is not None
|
||||
|
||||
if self.lora_path:
|
||||
if llama_cpp.llama_apply_lora_from_file(
|
||||
self.ctx,
|
||||
llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
|
||||
llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
|
||||
if self.lora_base is not None
|
||||
else llama_cpp.c_char_p(0),
|
||||
llama_cpp.c_int(self.n_threads),
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr)
|
||||
|
||||
|
@ -158,18 +208,12 @@ class Llama:
|
|||
Args:
|
||||
cache: The cache to set.
|
||||
"""
|
||||
self._cache = cache
|
||||
self.cache = cache
|
||||
|
||||
def reset(self):
|
||||
"""Reset the model state."""
|
||||
self.last_n_tokens_data.extend(
|
||||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
||||
)
|
||||
self.tokens_consumed = 0
|
||||
self.tokens.clear()
|
||||
self.n_tokens = 0
|
||||
self.n_past = 0
|
||||
self.all_logits.clear()
|
||||
self.eval_tokens.clear()
|
||||
self.eval_logits.clear()
|
||||
|
||||
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
|
||||
"""Evaluate a list of tokens.
|
||||
|
@ -181,32 +225,28 @@ class Llama:
|
|||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||
for i in range(0, len(tokens), self.n_batch):
|
||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||
self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||
self.n_tokens = len(batch)
|
||||
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
|
||||
n_tokens = len(batch)
|
||||
return_code = llama_cpp.llama_eval(
|
||||
ctx=self.ctx,
|
||||
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
||||
n_tokens=llama_cpp.c_int(self.n_tokens),
|
||||
n_past=llama_cpp.c_int(self.n_past),
|
||||
n_tokens=llama_cpp.c_int(n_tokens),
|
||||
n_past=llama_cpp.c_int(n_past),
|
||||
n_threads=llama_cpp.c_int(self.n_threads),
|
||||
)
|
||||
if int(return_code) != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
self.tokens.extend(batch)
|
||||
self.last_n_tokens_data.extend(batch)
|
||||
self.tokens_consumed += len(batch)
|
||||
self.eval_tokens.extend(batch)
|
||||
if self.params.logits_all:
|
||||
self.all_logits.extend(self._logits())
|
||||
|
||||
def _logits(self) -> List[List[float]]:
|
||||
"""Return the logits from the last call to llama_eval."""
|
||||
assert self.ctx is not None
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
rows = self.n_tokens if self.params.logits_all else 1
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
|
||||
return logits
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
rows = n_tokens
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits = [
|
||||
[logits_view[i * cols + j] for j in range(cols)]
|
||||
for i in range(rows)
|
||||
]
|
||||
self.eval_logits.extend(logits)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
@ -227,10 +267,13 @@ class Llama:
|
|||
The sampled token.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
|
||||
0, self.last_n_tokens_size - len(self.eval_tokens)
|
||||
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
|
||||
return llama_cpp.llama_sample_top_p_top_k(
|
||||
ctx=self.ctx,
|
||||
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
||||
*self.last_n_tokens_data
|
||||
*last_n_tokens_data
|
||||
),
|
||||
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
|
||||
top_k=llama_cpp.c_int(top_k),
|
||||
|
@ -270,18 +313,17 @@ class Llama:
|
|||
The generated tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
### HACK
|
||||
|
||||
if (
|
||||
reset
|
||||
and self._cache
|
||||
and len(self.tokens) > 0
|
||||
and self.tokens == tokens[: len(self.tokens)]
|
||||
and len(self.eval_tokens) > 0
|
||||
and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
|
||||
):
|
||||
if self.verbose:
|
||||
print("generate cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
tokens = tokens[len(self.tokens) :]
|
||||
###
|
||||
tokens = tokens[len(self.eval_tokens) :]
|
||||
|
||||
if reset:
|
||||
self.reset()
|
||||
while True:
|
||||
|
@ -398,20 +440,10 @@ class Llama:
|
|||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
||||
### HACK
|
||||
reset: bool = True
|
||||
_prompt: bytes = prompt.encode("utf-8")
|
||||
_completion: bytes = b"".join(self._completion_bytes)
|
||||
if len(_completion) and self._cache and _prompt.startswith(_completion):
|
||||
if self.cache and prompt_tokens in self.cache:
|
||||
if self.verbose:
|
||||
print("completion cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
_prompt = _prompt[len(_completion) :]
|
||||
prompt_tokens = self.tokenize(b" " + _prompt)
|
||||
self._completion_bytes.append(_prompt)
|
||||
else:
|
||||
self._completion_bytes = [prompt.encode("utf-8")]
|
||||
###
|
||||
print("cache hit", file=sys.stderr)
|
||||
self.load_state(self.cache[prompt_tokens])
|
||||
|
||||
finish_reason = "length"
|
||||
for token in self.generate(
|
||||
|
@ -420,12 +452,18 @@ class Llama:
|
|||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
reset=reset,
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
if self.cache and len(completion_tokens) == 0:
|
||||
if prompt_tokens not in self.cache:
|
||||
if self.verbose:
|
||||
print("cache miss", file=sys.stderr)
|
||||
self.cache[prompt_tokens] = self.save_state()
|
||||
|
||||
completion_tokens.append(token)
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
|
@ -450,9 +488,6 @@ class Llama:
|
|||
break
|
||||
text = all_text[: len(all_text) - longest]
|
||||
returned_characters += len(text[start:])
|
||||
### HACK
|
||||
self._completion_bytes.append(text[start:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -474,9 +509,6 @@ class Llama:
|
|||
break
|
||||
|
||||
if stream:
|
||||
### HACK
|
||||
self._completion_bytes.append(text[returned_characters:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -493,9 +525,6 @@ class Llama:
|
|||
}
|
||||
return
|
||||
|
||||
### HACK
|
||||
self._completion_bytes.append(text)
|
||||
###
|
||||
text_str = text.decode("utf-8", errors="ignore")
|
||||
|
||||
if echo:
|
||||
|
@ -518,7 +547,7 @@ class Llama:
|
|||
]
|
||||
all_logprobs = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
for row in self.all_logits
|
||||
for row in self.eval_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
|
@ -802,6 +831,8 @@ class Llama:
|
|||
last_n_tokens_size=self.last_n_tokens_size,
|
||||
n_batch=self.n_batch,
|
||||
n_threads=self.n_threads,
|
||||
lora_base=self.lora_base,
|
||||
lora_path=self.lora_path,
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
@ -819,9 +850,31 @@ class Llama:
|
|||
n_threads=state["n_threads"],
|
||||
n_batch=state["n_batch"],
|
||||
last_n_tokens_size=state["last_n_tokens_size"],
|
||||
lora_base=state["lora_base"],
|
||||
lora_path=state["lora_path"],
|
||||
verbose=state["verbose"],
|
||||
)
|
||||
|
||||
def save_state(self) -> LlamaState:
|
||||
assert self.ctx is not None
|
||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
||||
llama_state = (llama_cpp.c_uint8 * int(state_size))()
|
||||
if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
|
||||
raise RuntimeError("Failed to copy llama state data")
|
||||
return LlamaState(
|
||||
eval_tokens=self.eval_tokens.copy(),
|
||||
eval_logits=self.eval_logits.copy(),
|
||||
llama_state=llama_state,
|
||||
)
|
||||
|
||||
def load_state(self, state: LlamaState) -> None:
|
||||
assert self.ctx is not None
|
||||
self.eval_tokens = state.eval_tokens.copy()
|
||||
self.eval_logits = state.eval_logits.copy()
|
||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
||||
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
|
||||
raise RuntimeError("Failed to set llama state data")
|
||||
|
||||
@staticmethod
|
||||
def token_eos() -> llama_cpp.llama_token:
|
||||
"""Return the end-of-sequence token."""
|
||||
|
|
|
@ -114,7 +114,12 @@ LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
|
|||
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(4) # tok_embeddings.weight and output.weight are F16
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(
|
||||
4
|
||||
) # tok_embeddings.weight and output.weight are F16
|
||||
LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes.c_int(5) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTYL_Q8_0 = ctypes.c_int(7) # except 1d tensors
|
||||
|
||||
# Functions
|
||||
|
||||
|
@ -167,31 +172,34 @@ _lib.llama_free.restype = None
|
|||
|
||||
# TODO: not great API - very likely to change
|
||||
# Returns 0 on success
|
||||
def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int:
|
||||
return _lib.llama_model_quantize(fname_inp, fname_out, itype)
|
||||
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
|
||||
def llama_model_quantize(
|
||||
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
|
||||
) -> c_int:
|
||||
return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
|
||||
|
||||
|
||||
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
|
||||
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
|
||||
_lib.llama_model_quantize.restype = c_int
|
||||
|
||||
|
||||
# Returns the KV cache that will contain the context for the
|
||||
# ongoing prediction with the model.
|
||||
def llama_get_kv_cache(ctx: llama_context_p):
|
||||
return _lib.llama_get_kv_cache(ctx)
|
||||
# Apply a LoRA adapter to a loaded model
|
||||
# path_base_model is the path to a higher quality model to use as a base for
|
||||
# the layers modified by the adapter. Can be NULL to use the current loaded model.
|
||||
# The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
# will be applied on top of the previous one
|
||||
# Returns 0 on success
|
||||
def llama_apply_lora_from_file(
|
||||
ctx: llama_context_p,
|
||||
path_lora: ctypes.c_char_p,
|
||||
path_base_model: ctypes.c_char_p,
|
||||
n_threads: c_int,
|
||||
) -> c_int:
|
||||
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
|
||||
|
||||
|
||||
_lib.llama_get_kv_cache.argtypes = [llama_context_p]
|
||||
_lib.llama_get_kv_cache.restype = POINTER(c_uint8)
|
||||
|
||||
|
||||
# Returns the size of the KV cache
|
||||
def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
|
||||
return _lib.llama_get_kv_cache_size(ctx)
|
||||
|
||||
|
||||
_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
|
||||
_lib.llama_get_kv_cache_size.restype = c_size_t
|
||||
_lib.llama_apply_lora_from_file.argtypes = [llama_context_p, c_char_p, c_char_p, c_int]
|
||||
_lib.llama_apply_lora_from_file.restype = c_int
|
||||
|
||||
|
||||
# Returns the number of tokens in the KV cache
|
||||
|
@ -203,15 +211,34 @@ _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
|
|||
_lib.llama_get_kv_cache_token_count.restype = c_int
|
||||
|
||||
|
||||
# Sets the KV cache containing the current context for the model
|
||||
def llama_set_kv_cache(
|
||||
ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int
|
||||
):
|
||||
return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
|
||||
# Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
|
||||
def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
|
||||
return _lib.llama_get_state_size(ctx)
|
||||
|
||||
|
||||
_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
|
||||
_lib.llama_set_kv_cache.restype = None
|
||||
_lib.llama_get_state_size.argtypes = [llama_context_p]
|
||||
_lib.llama_get_state_size.restype = c_size_t
|
||||
|
||||
|
||||
# Copies the state to the specified destination address.
|
||||
# Destination needs to have allocated enough memory.
|
||||
# Returns the number of bytes copied
|
||||
def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
|
||||
return _lib.llama_copy_state_data(ctx, dest)
|
||||
|
||||
|
||||
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
|
||||
_lib.llama_copy_state_data.restype = c_size_t
|
||||
|
||||
|
||||
# Set the state reading from the specified address
|
||||
# Returns the number of bytes read
|
||||
def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
|
||||
return _lib.llama_set_state_data(ctx, src)
|
||||
|
||||
|
||||
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
|
||||
_lib.llama_set_state_data.restype = c_size_t
|
||||
|
||||
|
||||
# Run the llama inference to obtain the logits and probabilities for the next token.
|
||||
|
|
|
@ -28,10 +28,11 @@ from sse_starlette.sse import EventSourceResponse
|
|||
class Settings(BaseSettings):
|
||||
model: str
|
||||
n_ctx: int = 2048
|
||||
n_batch: int = 8
|
||||
n_threads: int = ((os.cpu_count() or 2) // 2) or 1
|
||||
n_batch: int = 512
|
||||
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
||||
f16_kv: bool = True
|
||||
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
|
||||
use_mmap: bool = True
|
||||
embedding: bool = True
|
||||
last_n_tokens_size: int = 64
|
||||
logits_all: bool = False
|
||||
|
@ -54,6 +55,7 @@ llama = llama_cpp.Llama(
|
|||
settings.model,
|
||||
f16_kv=settings.f16_kv,
|
||||
use_mlock=settings.use_mlock,
|
||||
use_mmap=settings.use_mmap,
|
||||
embedding=settings.embedding,
|
||||
logits_all=settings.logits_all,
|
||||
n_threads=settings.n_threads,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "llama_cpp_python"
|
||||
version = "0.1.34"
|
||||
version = "0.1.38"
|
||||
description = "Python bindings for the llama.cpp library"
|
||||
authors = ["Andrei Betlen <abetlen@gmail.com>"]
|
||||
license = "MIT"
|
||||
|
|
2
setup.py
2
setup.py
|
@ -10,7 +10,7 @@ setup(
|
|||
description="A Python wrapper for llama.cpp",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
version="0.1.34",
|
||||
version="0.1.38",
|
||||
author="Andrei Betlen",
|
||||
author_email="abetlen@gmail.com",
|
||||
license="MIT",
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit e95b6554b493e71a0275764342e09bd5784a7026
|
||||
Subproject commit 4afcc378698e057fcde64e23eb664e5af8dd6956
|
Loading…
Reference in a new issue