Fix eval logits type

This commit is contained in:
Andrei Betlen 2023-05-05 14:23:14 -04:00
parent b5f3e74627
commit 98bbd1c6a8

View file

@ -127,7 +127,7 @@ class Llama:
self.last_n_tokens_size = last_n_tokens_size self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch) self.n_batch = min(n_ctx, n_batch)
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx) self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque( self.eval_logits: Deque[List[float]] = deque(
maxlen=n_ctx if logits_all else 1 maxlen=n_ctx if logits_all else 1
) )
@ -245,7 +245,7 @@ class Llama:
n_vocab = llama_cpp.llama_n_vocab(self.ctx) n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab) cols = int(n_vocab)
logits_view = llama_cpp.llama_get_logits(self.ctx) logits_view = llama_cpp.llama_get_logits(self.ctx)
logits: List[List[llama_cpp.c_float]] = [ logits: List[List[float]] = [
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows) [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
] ]
self.eval_logits.extend(logits) self.eval_logits.extend(logits)
@ -287,7 +287,7 @@ class Llama:
candidates=llama_cpp.ctypes.pointer(candidates), candidates=llama_cpp.ctypes.pointer(candidates),
penalty=repeat_penalty, penalty=repeat_penalty,
) )
if float(temp) == 0.0: if float(temp.value) == 0.0:
return llama_cpp.llama_sample_token_greedy( return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates), candidates=llama_cpp.ctypes.pointer(candidates),