Track generated tokens internally

Andrei Betlen 2023-04-14 23:33:00 -04:00
parent 25b646c2fb
commit ac7068a469


@@ -76,6 +76,7 @@ class Llama:
             maxlen=self.last_n_tokens_size,
         )
         self.tokens_consumed = 0
+        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
         self.n_tokens = 0
         self.n_past = 0
@@ -140,6 +141,7 @@ class Llama:
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.tokens.clear()
         self.n_tokens = 0
         self.n_past = 0
         self.all_logits = []
@@ -165,6 +167,7 @@ class Llama:
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            self.tokens.extend(batch)
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
             if self.params.logits_all:
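
Taken together, the three hunks make `Llama` keep a full running history of every token it has evaluated: the list is created in `__init__`, emptied in `reset()`, and appended to after each successful `llama_eval` call in `eval()`. Below is a minimal standalone sketch of that same bookkeeping pattern. The `TokenTracker` class name, the `_llama_eval` stub, and the plain-`int` token type are illustrative stand-ins (so the sketch runs without the llama.cpp bindings), not part of the commit:

from collections import deque
from typing import Deque, List


class TokenTracker:
    """Minimal sketch of the token bookkeeping added in this commit."""

    def __init__(self, last_n_tokens_size: int = 64):
        self.last_n_tokens_size = last_n_tokens_size
        self.last_n_tokens_data: Deque[int] = deque(
            [0] * last_n_tokens_size,
            maxlen=last_n_tokens_size,
        )
        self.tokens_consumed = 0
        self.tokens: List[int] = []  # new: full history of evaluated tokens

    def reset(self) -> None:
        self.last_n_tokens_data.extend([0] * self.last_n_tokens_size)
        self.tokens_consumed = 0
        self.tokens.clear()  # new: drop the history along with the rest of the state

    def eval(self, batch: List[int]) -> None:
        return_code = self._llama_eval(batch)  # stand-in for llama_cpp.llama_eval
        if int(return_code) != 0:
            raise RuntimeError(f"llama_eval returned {return_code}")
        self.tokens.extend(batch)  # new: record only tokens that were actually evaluated
        self.last_n_tokens_data.extend(batch)
        self.tokens_consumed += len(batch)

    def _llama_eval(self, batch: List[int]) -> int:
        # Illustrative stub: the real class calls into llama.cpp here.
        return 0


tracker = TokenTracker()
tracker.eval([1, 2, 3])
tracker.eval([4])
assert tracker.tokens == [1, 2, 3, 4]
tracker.reset()
assert tracker.tokens == []

Note that `self.tokens.extend(batch)` runs only after the return code is checked, so a failed `llama_eval` never pollutes the history, and unlike the fixed-size `last_n_tokens_data` deque, the new list grows without bound until `reset()` is called.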