Potential bugfix for eval
This commit is contained in:
parent
52350cc9d7
commit
d9b38e3e3a
1 changed files with 2 additions and 3 deletions
|
@ -1019,12 +1019,11 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
assert self._ctx.ctx is not None
|
assert self._ctx.ctx is not None
|
||||||
assert self._batch.batch is not None
|
assert self._batch.batch is not None
|
||||||
n_ctx = self._n_ctx
|
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
||||||
for i in range(0, len(tokens), self.n_batch):
|
for i in range(0, len(tokens), self.n_batch):
|
||||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||||
n_past = min(n_ctx - len(batch), self.n_tokens)
|
n_past = self.n_tokens
|
||||||
n_tokens = len(batch)
|
n_tokens = len(batch)
|
||||||
self._ctx.kv_cache_seq_rm(-1, n_past, -1)
|
|
||||||
self._batch.set_batch(
|
self._batch.set_batch(
|
||||||
batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
|
batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue