Fix unnecessary memory allocation while sampling
This commit is contained in:
parent
fafe47114c
commit
03e2947b03
1 changed files with 33 additions and 21 deletions
|
@ -176,6 +176,28 @@ class Llama:
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
n_vocab = self.n_vocab()
|
||||||
|
n_ctx = self.n_ctx()
|
||||||
|
data = (llama_cpp.llama_token_data * n_vocab)(
|
||||||
|
*[
|
||||||
|
llama_cpp.llama_token_data(
|
||||||
|
id=llama_cpp.llama_token(i),
|
||||||
|
logit=llama_cpp.c_float(0.0),
|
||||||
|
p=llama_cpp.c_float(0.0),
|
||||||
|
)
|
||||||
|
for i in range(n_vocab)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
size = llama_cpp.c_size_t(n_vocab)
|
||||||
|
sorted = False
|
||||||
|
candidates = llama_cpp.llama_token_data_array(
|
||||||
|
data=data,
|
||||||
|
size=size,
|
||||||
|
sorted=sorted,
|
||||||
|
)
|
||||||
|
self._candidates = candidates
|
||||||
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
||||||
"""Tokenize a string.
|
"""Tokenize a string.
|
||||||
|
@ -296,8 +318,8 @@ class Llama:
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert len(self.eval_logits) > 0
|
assert len(self.eval_logits) > 0
|
||||||
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
|
n_vocab = self.n_vocab()
|
||||||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
n_ctx = self.n_ctx()
|
||||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||||
last_n_tokens_size = (
|
last_n_tokens_size = (
|
||||||
llama_cpp.c_int(n_ctx)
|
llama_cpp.c_int(n_ctx)
|
||||||
|
@ -305,24 +327,14 @@ class Llama:
|
||||||
else last_n_tokens_size
|
else last_n_tokens_size
|
||||||
)
|
)
|
||||||
logits = self.eval_logits[-1]
|
logits = self.eval_logits[-1]
|
||||||
nl_logit = logits[int(Llama.token_nl())]
|
nl_logit = logits[Llama.token_nl()]
|
||||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
candidates = self._candidates
|
||||||
*[
|
for i, logit in enumerate(logits):
|
||||||
llama_cpp.llama_token_data(
|
candidates.data[i].id = llama_cpp.llama_token(i)
|
||||||
id=llama_cpp.llama_token(i),
|
candidates.data[i].logit = llama_cpp.c_float(logit)
|
||||||
logit=logits[i],
|
candidates.data[i].p = llama_cpp.c_float(0.0)
|
||||||
p=llama_cpp.c_float(0.0),
|
candidates.sorted = llama_cpp.c_bool(False)
|
||||||
)
|
candidates.size = llama_cpp.c_size_t(n_vocab)
|
||||||
for i in range(n_vocab)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
size = llama_cpp.c_size_t(n_vocab)
|
|
||||||
sorted = False
|
|
||||||
candidates = llama_cpp.llama_token_data_array(
|
|
||||||
data=data,
|
|
||||||
size=size,
|
|
||||||
sorted=sorted,
|
|
||||||
)
|
|
||||||
llama_cpp.llama_sample_repetition_penalty(
|
llama_cpp.llama_sample_repetition_penalty(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
last_tokens_data=last_n_tokens_data,
|
last_tokens_data=last_n_tokens_data,
|
||||||
|
@ -339,7 +351,7 @@ class Llama:
|
||||||
alpha_presence=presence_penalty,
|
alpha_presence=presence_penalty,
|
||||||
)
|
)
|
||||||
if not penalize_nl:
|
if not penalize_nl:
|
||||||
candidates.data[int(Llama.token_nl())].logit = nl_logit
|
candidates.data[Llama.token_nl()].logit = nl_logit
|
||||||
if temp.value == 0.0:
|
if temp.value == 0.0:
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
return llama_cpp.llama_sample_token_greedy(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
|
|
Loading…
Reference in a new issue