Fix unnecessary memory allocation while sampling

This commit is contained in:
Andrei Betlen 2023-05-21 18:36:34 -04:00
parent fafe47114c
commit 03e2947b03

View file

@ -176,6 +176,28 @@ class Llama:
if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
n_vocab = self.n_vocab()
n_ctx = self.n_ctx()
data = (llama_cpp.llama_token_data * n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=llama_cpp.c_float(0.0),
p=llama_cpp.c_float(0.0),
)
for i in range(n_vocab)
]
)
size = llama_cpp.c_size_t(n_vocab)
sorted = False
candidates = llama_cpp.llama_token_data_array(
data=data,
size=size,
sorted=sorted,
)
self._candidates = candidates
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
"""Tokenize a string.
@ -296,8 +318,8 @@ class Llama:
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
n_vocab = self.n_vocab()
n_ctx = self.n_ctx()
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
last_n_tokens_size = (
llama_cpp.c_int(n_ctx)
@ -305,24 +327,14 @@ class Llama:
else last_n_tokens_size
)
logits = self.eval_logits[-1]
nl_logit = logits[int(Llama.token_nl())]
data = (llama_cpp.llama_token_data * n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=logits[i],
p=llama_cpp.c_float(0.0),
)
for i in range(n_vocab)
]
)
size = llama_cpp.c_size_t(n_vocab)
sorted = False
candidates = llama_cpp.llama_token_data_array(
data=data,
size=size,
sorted=sorted,
)
nl_logit = logits[Llama.token_nl()]
candidates = self._candidates
for i, logit in enumerate(logits):
candidates.data[i].id = llama_cpp.llama_token(i)
candidates.data[i].logit = llama_cpp.c_float(logit)
candidates.data[i].p = llama_cpp.c_float(0.0)
candidates.sorted = llama_cpp.c_bool(False)
candidates.size = llama_cpp.c_size_t(n_vocab)
llama_cpp.llama_sample_repetition_penalty(
ctx=self.ctx,
last_tokens_data=last_n_tokens_data,
@ -339,7 +351,7 @@ class Llama:
alpha_presence=presence_penalty,
)
if not penalize_nl:
candidates.data[int(Llama.token_nl())].logit = nl_logit
candidates.data[Llama.token_nl()].logit = nl_logit
if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,