Remove unnecessary ffi calls

This commit is contained in:
Andrei Betlen 2023-05-23 17:56:21 -04:00
parent e5d596e0e9
commit fab064ded9

View file

@ -177,19 +177,19 @@ class Llama:
if self.verbose: if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
n_vocab = self.n_vocab() self._n_vocab = self.n_vocab()
n_ctx = self.n_ctx() self._n_ctx = self.n_ctx()
data = (llama_cpp.llama_token_data * n_vocab)( data = (llama_cpp.llama_token_data * self._n_vocab)(
*[ *[
llama_cpp.llama_token_data( llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i), id=llama_cpp.llama_token(i),
logit=llama_cpp.c_float(0.0), logit=llama_cpp.c_float(0.0),
p=llama_cpp.c_float(0.0), p=llama_cpp.c_float(0.0),
) )
for i in range(n_vocab) for i in range(self._n_vocab)
] ]
) )
size = llama_cpp.c_size_t(n_vocab) size = llama_cpp.c_size_t(self._n_vocab)
sorted = False sorted = False
candidates = llama_cpp.llama_token_data_array( candidates = llama_cpp.llama_token_data_array(
data=data, data=data,
@ -213,8 +213,8 @@ class Llama:
A list of tokens. A list of tokens.
""" """
assert self.ctx is not None assert self.ctx is not None
n_ctx = llama_cpp.llama_n_ctx(self.ctx) n_ctx = self._n_ctx
tokens = (llama_cpp.llama_token * int(n_ctx))() tokens = (llama_cpp.llama_token * n_ctx)()
n_tokens = llama_cpp.llama_tokenize( n_tokens = llama_cpp.llama_tokenize(
self.ctx, self.ctx,
text, text,
@ -222,9 +222,9 @@ class Llama:
llama_cpp.c_int(n_ctx), llama_cpp.c_int(n_ctx),
llama_cpp.c_bool(add_bos), llama_cpp.c_bool(add_bos),
) )
if int(n_tokens) < 0: if n_tokens < 0:
n_tokens = abs(n_tokens) n_tokens = abs(n_tokens)
tokens = (llama_cpp.llama_token * int(n_tokens))() tokens = (llama_cpp.llama_token * n_tokens)()
n_tokens = llama_cpp.llama_tokenize( n_tokens = llama_cpp.llama_tokenize(
self.ctx, self.ctx,
text, text,
@ -275,7 +275,7 @@ class Llama:
tokens: The list of tokens to evaluate. tokens: The list of tokens to evaluate.
""" """
assert self.ctx is not None assert self.ctx is not None
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) n_ctx = self._n_ctx
for i in range(0, len(tokens), self.n_batch): for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)] batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), len(self.eval_tokens)) n_past = min(n_ctx - len(batch), len(self.eval_tokens))
@ -287,18 +287,16 @@ class Llama:
n_past=llama_cpp.c_int(n_past), n_past=llama_cpp.c_int(n_past),
n_threads=llama_cpp.c_int(self.n_threads), n_threads=llama_cpp.c_int(self.n_threads),
) )
if int(return_code) != 0: if return_code != 0:
raise RuntimeError(f"llama_eval returned {return_code}") raise RuntimeError(f"llama_eval returned {return_code}")
# Save tokens # Save tokens
self.eval_tokens.extend(batch) self.eval_tokens.extend(batch)
# Save logits # Save logits
rows = n_tokens if self.params.logits_all else 1 rows = n_tokens if self.params.logits_all else 1
n_vocab = llama_cpp.llama_n_vocab(self.ctx) n_vocab = self._n_vocab
cols = int(n_vocab) cols = n_vocab
logits_view = llama_cpp.llama_get_logits(self.ctx) logits_view = llama_cpp.llama_get_logits(self.ctx)
logits: List[List[float]] = [ logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
]
self.eval_logits.extend(logits) self.eval_logits.extend(logits)
def _sample( def _sample(
@ -319,8 +317,8 @@ class Llama:
): ):
assert self.ctx is not None assert self.ctx is not None
assert len(self.eval_logits) > 0 assert len(self.eval_logits) > 0
n_vocab = self.n_vocab() n_vocab = self._n_vocab
n_ctx = self.n_ctx() n_ctx = self._n_ctx
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
last_n_tokens_size = ( last_n_tokens_size = (
llama_cpp.c_int(n_ctx) llama_cpp.c_int(n_ctx)
@ -654,9 +652,9 @@ class Llama:
if self.verbose: if self.verbose:
llama_cpp.llama_reset_timings(self.ctx) llama_cpp.llama_reset_timings(self.ctx)
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)): if len(prompt_tokens) + max_tokens > self._n_ctx:
raise ValueError( raise ValueError(
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" f"Requested tokens exceed context window of {self._n_ctx}"
) )
if stop != []: if stop != []: