Track generated tokens internally
This commit is contained in:
parent
25b646c2fb
commit
ac7068a469
1 changed files with 3 additions and 0 deletions
|
@ -76,6 +76,7 @@ class Llama:
|
||||||
maxlen=self.last_n_tokens_size,
|
maxlen=self.last_n_tokens_size,
|
||||||
)
|
)
|
||||||
self.tokens_consumed = 0
|
self.tokens_consumed = 0
|
||||||
|
self.tokens: List[llama_cpp.llama_token] = []
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
self.n_tokens = 0
|
self.n_tokens = 0
|
||||||
self.n_past = 0
|
self.n_past = 0
|
||||||
|
@ -140,6 +141,7 @@ class Llama:
|
||||||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
||||||
)
|
)
|
||||||
self.tokens_consumed = 0
|
self.tokens_consumed = 0
|
||||||
|
self.tokens.clear()
|
||||||
self.n_tokens = 0
|
self.n_tokens = 0
|
||||||
self.n_past = 0
|
self.n_past = 0
|
||||||
self.all_logits = []
|
self.all_logits = []
|
||||||
|
@ -165,6 +167,7 @@ class Llama:
|
||||||
)
|
)
|
||||||
if int(return_code) != 0:
|
if int(return_code) != 0:
|
||||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||||
|
self.tokens.extend(batch)
|
||||||
self.last_n_tokens_data.extend(batch)
|
self.last_n_tokens_data.extend(batch)
|
||||||
self.tokens_consumed += len(batch)
|
self.tokens_consumed += len(batch)
|
||||||
if self.params.logits_all:
|
if self.params.logits_all:
|
||||||
|
|
Loading…
Reference in a new issue