Fix unnecessary memory allocation while sampling

This commit is contained in:
Andrei Betlen 2023-05-21 18:36:34 -04:00
parent fafe47114c
commit 03e2947b03

View file

@ -177,6 +177,28 @@ class Llama:
if self.verbose: if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
n_vocab = self.n_vocab()
n_ctx = self.n_ctx()
data = (llama_cpp.llama_token_data * n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=llama_cpp.c_float(0.0),
p=llama_cpp.c_float(0.0),
)
for i in range(n_vocab)
]
)
size = llama_cpp.c_size_t(n_vocab)
sorted = False
candidates = llama_cpp.llama_token_data_array(
data=data,
size=size,
sorted=sorted,
)
self._candidates = candidates
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
"""Tokenize a string. """Tokenize a string.
@ -296,8 +318,8 @@ class Llama:
): ):
assert self.ctx is not None assert self.ctx is not None
assert len(self.eval_logits) > 0 assert len(self.eval_logits) > 0
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx)) n_vocab = self.n_vocab()
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) n_ctx = self.n_ctx()
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
last_n_tokens_size = ( last_n_tokens_size = (
llama_cpp.c_int(n_ctx) llama_cpp.c_int(n_ctx)
@ -305,24 +327,14 @@ class Llama:
else last_n_tokens_size else last_n_tokens_size
) )
logits = self.eval_logits[-1] logits = self.eval_logits[-1]
nl_logit = logits[int(Llama.token_nl())] nl_logit = logits[Llama.token_nl()]
data = (llama_cpp.llama_token_data * n_vocab)( candidates = self._candidates
*[ for i, logit in enumerate(logits):
llama_cpp.llama_token_data( candidates.data[i].id = llama_cpp.llama_token(i)
id=llama_cpp.llama_token(i), candidates.data[i].logit = llama_cpp.c_float(logit)
logit=logits[i], candidates.data[i].p = llama_cpp.c_float(0.0)
p=llama_cpp.c_float(0.0), candidates.sorted = llama_cpp.c_bool(False)
) candidates.size = llama_cpp.c_size_t(n_vocab)
for i in range(n_vocab)
]
)
size = llama_cpp.c_size_t(n_vocab)
sorted = False
candidates = llama_cpp.llama_token_data_array(
data=data,
size=size,
sorted=sorted,
)
llama_cpp.llama_sample_repetition_penalty( llama_cpp.llama_sample_repetition_penalty(
ctx=self.ctx, ctx=self.ctx,
last_tokens_data=last_n_tokens_data, last_tokens_data=last_n_tokens_data,
@ -339,7 +351,7 @@ class Llama:
alpha_presence=presence_penalty, alpha_presence=presence_penalty,
) )
if not penalize_nl: if not penalize_nl:
candidates.data[int(Llama.token_nl())].logit = nl_logit candidates.data[Llama.token_nl()].logit = nl_logit
if temp.value == 0.0: if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy( return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx, ctx=self.ctx,