Fix unnecessary memory allocation while sampling
This commit is contained in:
parent
fafe47114c
commit
03e2947b03
1 changed files with 33 additions and 21 deletions
|
@ -177,6 +177,28 @@ class Llama:
|
|||
if self.verbose:
|
||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||
|
||||
|
||||
n_vocab = self.n_vocab()
|
||||
n_ctx = self.n_ctx()
|
||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
||||
*[
|
||||
llama_cpp.llama_token_data(
|
||||
id=llama_cpp.llama_token(i),
|
||||
logit=llama_cpp.c_float(0.0),
|
||||
p=llama_cpp.c_float(0.0),
|
||||
)
|
||||
for i in range(n_vocab)
|
||||
]
|
||||
)
|
||||
size = llama_cpp.c_size_t(n_vocab)
|
||||
sorted = False
|
||||
candidates = llama_cpp.llama_token_data_array(
|
||||
data=data,
|
||||
size=size,
|
||||
sorted=sorted,
|
||||
)
|
||||
self._candidates = candidates
|
||||
|
||||
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
||||
"""Tokenize a string.
|
||||
|
||||
|
@ -296,8 +318,8 @@ class Llama:
|
|||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
|
||||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||
n_vocab = self.n_vocab()
|
||||
n_ctx = self.n_ctx()
|
||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||
last_n_tokens_size = (
|
||||
llama_cpp.c_int(n_ctx)
|
||||
|
@ -305,24 +327,14 @@ class Llama:
|
|||
else last_n_tokens_size
|
||||
)
|
||||
logits = self.eval_logits[-1]
|
||||
nl_logit = logits[int(Llama.token_nl())]
|
||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
||||
*[
|
||||
llama_cpp.llama_token_data(
|
||||
id=llama_cpp.llama_token(i),
|
||||
logit=logits[i],
|
||||
p=llama_cpp.c_float(0.0),
|
||||
)
|
||||
for i in range(n_vocab)
|
||||
]
|
||||
)
|
||||
size = llama_cpp.c_size_t(n_vocab)
|
||||
sorted = False
|
||||
candidates = llama_cpp.llama_token_data_array(
|
||||
data=data,
|
||||
size=size,
|
||||
sorted=sorted,
|
||||
)
|
||||
nl_logit = logits[Llama.token_nl()]
|
||||
candidates = self._candidates
|
||||
for i, logit in enumerate(logits):
|
||||
candidates.data[i].id = llama_cpp.llama_token(i)
|
||||
candidates.data[i].logit = llama_cpp.c_float(logit)
|
||||
candidates.data[i].p = llama_cpp.c_float(0.0)
|
||||
candidates.sorted = llama_cpp.c_bool(False)
|
||||
candidates.size = llama_cpp.c_size_t(n_vocab)
|
||||
llama_cpp.llama_sample_repetition_penalty(
|
||||
ctx=self.ctx,
|
||||
last_tokens_data=last_n_tokens_data,
|
||||
|
@ -339,7 +351,7 @@ class Llama:
|
|||
alpha_presence=presence_penalty,
|
||||
)
|
||||
if not penalize_nl:
|
||||
candidates.data[int(Llama.token_nl())].logit = nl_logit
|
||||
candidates.data[Llama.token_nl()].logit = nl_logit
|
||||
if temp.value == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
|
|
Loading…
Reference in a new issue