Remove unnecessary ffi calls
This commit is contained in:
parent
e5d596e0e9
commit
fab064ded9
1 changed files with 18 additions and 20 deletions
|
@ -177,19 +177,19 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||||
|
|
||||||
n_vocab = self.n_vocab()
|
self._n_vocab = self.n_vocab()
|
||||||
n_ctx = self.n_ctx()
|
self._n_ctx = self.n_ctx()
|
||||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
data = (llama_cpp.llama_token_data * self._n_vocab)(
|
||||||
*[
|
*[
|
||||||
llama_cpp.llama_token_data(
|
llama_cpp.llama_token_data(
|
||||||
id=llama_cpp.llama_token(i),
|
id=llama_cpp.llama_token(i),
|
||||||
logit=llama_cpp.c_float(0.0),
|
logit=llama_cpp.c_float(0.0),
|
||||||
p=llama_cpp.c_float(0.0),
|
p=llama_cpp.c_float(0.0),
|
||||||
)
|
)
|
||||||
for i in range(n_vocab)
|
for i in range(self._n_vocab)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
size = llama_cpp.c_size_t(n_vocab)
|
size = llama_cpp.c_size_t(self._n_vocab)
|
||||||
sorted = False
|
sorted = False
|
||||||
candidates = llama_cpp.llama_token_data_array(
|
candidates = llama_cpp.llama_token_data_array(
|
||||||
data=data,
|
data=data,
|
||||||
|
@ -213,8 +213,8 @@ class Llama:
|
||||||
A list of tokens.
|
A list of tokens.
|
||||||
"""
|
"""
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
n_ctx = llama_cpp.llama_n_ctx(self.ctx)
|
n_ctx = self._n_ctx
|
||||||
tokens = (llama_cpp.llama_token * int(n_ctx))()
|
tokens = (llama_cpp.llama_token * n_ctx)()
|
||||||
n_tokens = llama_cpp.llama_tokenize(
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
text,
|
text,
|
||||||
|
@ -222,9 +222,9 @@ class Llama:
|
||||||
llama_cpp.c_int(n_ctx),
|
llama_cpp.c_int(n_ctx),
|
||||||
llama_cpp.c_bool(add_bos),
|
llama_cpp.c_bool(add_bos),
|
||||||
)
|
)
|
||||||
if int(n_tokens) < 0:
|
if n_tokens < 0:
|
||||||
n_tokens = abs(n_tokens)
|
n_tokens = abs(n_tokens)
|
||||||
tokens = (llama_cpp.llama_token * int(n_tokens))()
|
tokens = (llama_cpp.llama_token * n_tokens)()
|
||||||
n_tokens = llama_cpp.llama_tokenize(
|
n_tokens = llama_cpp.llama_tokenize(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
text,
|
text,
|
||||||
|
@ -275,7 +275,7 @@ class Llama:
|
||||||
tokens: The list of tokens to evaluate.
|
tokens: The list of tokens to evaluate.
|
||||||
"""
|
"""
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
n_ctx = self._n_ctx
|
||||||
for i in range(0, len(tokens), self.n_batch):
|
for i in range(0, len(tokens), self.n_batch):
|
||||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||||
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
|
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
|
||||||
|
@ -287,18 +287,16 @@ class Llama:
|
||||||
n_past=llama_cpp.c_int(n_past),
|
n_past=llama_cpp.c_int(n_past),
|
||||||
n_threads=llama_cpp.c_int(self.n_threads),
|
n_threads=llama_cpp.c_int(self.n_threads),
|
||||||
)
|
)
|
||||||
if int(return_code) != 0:
|
if return_code != 0:
|
||||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||||
# Save tokens
|
# Save tokens
|
||||||
self.eval_tokens.extend(batch)
|
self.eval_tokens.extend(batch)
|
||||||
# Save logits
|
# Save logits
|
||||||
rows = n_tokens if self.params.logits_all else 1
|
rows = n_tokens if self.params.logits_all else 1
|
||||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
n_vocab = self._n_vocab
|
||||||
cols = int(n_vocab)
|
cols = n_vocab
|
||||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||||
logits: List[List[float]] = [
|
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
|
||||||
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
|
|
||||||
]
|
|
||||||
self.eval_logits.extend(logits)
|
self.eval_logits.extend(logits)
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
|
@ -319,8 +317,8 @@ class Llama:
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert len(self.eval_logits) > 0
|
assert len(self.eval_logits) > 0
|
||||||
n_vocab = self.n_vocab()
|
n_vocab = self._n_vocab
|
||||||
n_ctx = self.n_ctx()
|
n_ctx = self._n_ctx
|
||||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||||
last_n_tokens_size = (
|
last_n_tokens_size = (
|
||||||
llama_cpp.c_int(n_ctx)
|
llama_cpp.c_int(n_ctx)
|
||||||
|
@ -654,9 +652,9 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
llama_cpp.llama_reset_timings(self.ctx)
|
llama_cpp.llama_reset_timings(self.ctx)
|
||||||
|
|
||||||
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
|
if len(prompt_tokens) + max_tokens > self._n_ctx:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
f"Requested tokens exceed context window of {self._n_ctx}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if stop != []:
|
if stop != []:
|
||||||
|
|
Loading…
Reference in a new issue