Refactor internal state for Llama class

This commit is contained in:
Andrei Betlen 2023-04-24 15:47:54 -04:00
parent 02cf881317
commit 86f8e5ad91

View file

@ -84,16 +84,9 @@ class Llama:
self.params.embedding = embedding self.params.embedding = embedding
self.last_n_tokens_size = last_n_tokens_size self.last_n_tokens_size = last_n_tokens_size
self.last_n_tokens_data = deque(
[llama_cpp.llama_token(0)] * self.last_n_tokens_size,
maxlen=self.last_n_tokens_size,
)
self.tokens_consumed = 0
self.tokens: List[llama_cpp.llama_token] = []
self.n_batch = min(n_ctx, n_batch) self.n_batch = min(n_ctx, n_batch)
self.n_tokens = 0 self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
self.n_past = 0 self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
### saving and restoring state, this allows us to continue a completion if the last ### saving and restoring state, this allows us to continue a completion if the last
@ -181,14 +174,8 @@ class Llama:
def reset(self): def reset(self):
"""Reset the model state.""" """Reset the model state."""
self.last_n_tokens_data.extend( self.eval_tokens.clear()
[llama_cpp.llama_token(0)] * self.last_n_tokens_size self.eval_logits.clear()
)
self.tokens_consumed = 0
self.tokens.clear()
self.n_tokens = 0
self.n_past = 0
self.all_logits.clear()
def eval(self, tokens: Sequence[llama_cpp.llama_token]): def eval(self, tokens: Sequence[llama_cpp.llama_token]):
"""Evaluate a list of tokens. """Evaluate a list of tokens.
@ -200,32 +187,25 @@ class Llama:
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
for i in range(0, len(tokens), self.n_batch): for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)] batch = tokens[i : min(len(tokens), i + self.n_batch)]
self.n_past = min(n_ctx - len(batch), self.tokens_consumed) n_past = min(n_ctx - len(batch), len(self.eval_tokens))
self.n_tokens = len(batch) n_tokens = len(batch)
return_code = llama_cpp.llama_eval( return_code = llama_cpp.llama_eval(
ctx=self.ctx, ctx=self.ctx,
tokens=(llama_cpp.llama_token * len(batch))(*batch), tokens=(llama_cpp.llama_token * len(batch))(*batch),
n_tokens=llama_cpp.c_int(self.n_tokens), n_tokens=llama_cpp.c_int(n_tokens),
n_past=llama_cpp.c_int(self.n_past), n_past=llama_cpp.c_int(n_past),
n_threads=llama_cpp.c_int(self.n_threads), n_threads=llama_cpp.c_int(self.n_threads),
) )
if int(return_code) != 0: if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}") raise RuntimeError(f"llama_eval returned {return_code}")
self.tokens.extend(batch) self.eval_tokens.extend(batch)
self.last_n_tokens_data.extend(batch)
self.tokens_consumed += len(batch)
if self.params.logits_all: if self.params.logits_all:
self.all_logits.extend(self._logits())
def _logits(self) -> List[List[float]]:
"""Return the logits from the last call to llama_eval."""
assert self.ctx is not None
n_vocab = llama_cpp.llama_n_vocab(self.ctx) n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab) cols = int(n_vocab)
rows = self.n_tokens if self.params.logits_all else 1 rows = n_tokens
logits_view = llama_cpp.llama_get_logits(self.ctx) logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)] logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
return logits self.eval_logits.extend(logits)
def sample( def sample(
self, self,
@ -246,10 +226,13 @@ class Llama:
The sampled token. The sampled token.
""" """
assert self.ctx is not None assert self.ctx is not None
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
return llama_cpp.llama_sample_top_p_top_k( return llama_cpp.llama_sample_top_p_top_k(
ctx=self.ctx, ctx=self.ctx,
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*self.last_n_tokens_data *last_n_tokens_data
), ),
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size), last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
top_k=llama_cpp.c_int(top_k), top_k=llama_cpp.c_int(top_k),
@ -293,13 +276,13 @@ class Llama:
if ( if (
reset reset
and self._cache and self._cache
and len(self.tokens) > 0 and len(self.eval_tokens) > 0
and self.tokens == tokens[: len(self.tokens)] and self.eval_tokens == tokens[: len(self.eval_tokens)]
): ):
if self.verbose: if self.verbose:
print("generate cache hit", file=sys.stderr) print("generate cache hit", file=sys.stderr)
reset = False reset = False
tokens = tokens[len(self.tokens) :] tokens = tokens[len(self.eval_tokens) :]
### ###
if reset: if reset:
self.reset() self.reset()
@ -537,7 +520,7 @@ class Llama:
] ]
all_logprobs = [ all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row] [Llama.logit_to_logprob(logit) for logit in row]
for row in self.all_logits for row in self.eval_logits
] ]
for token, token_str, logprobs_token in zip( for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs all_tokens, all_token_strs, all_logprobs