Fix eval logits type
This commit is contained in:
parent
b5f3e74627
commit
98bbd1c6a8
1 changed files with 3 additions and 3 deletions
|
@ -127,7 +127,7 @@ class Llama:
|
|||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
|
||||
self.eval_logits: Deque[List[float]] = deque(
|
||||
maxlen=n_ctx if logits_all else 1
|
||||
)
|
||||
|
||||
|
@ -245,7 +245,7 @@ class Llama:
|
|||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits: List[List[llama_cpp.c_float]] = [
|
||||
logits: List[List[float]] = [
|
||||
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
|
||||
]
|
||||
self.eval_logits.extend(logits)
|
||||
|
@ -287,7 +287,7 @@ class Llama:
|
|||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
penalty=repeat_penalty,
|
||||
)
|
||||
if float(temp) == 0.0:
|
||||
if float(temp.value) == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
|
|
Loading…
Reference in a new issue