Fix eval logits type
This commit is contained in:
parent
b5f3e74627
commit
98bbd1c6a8
1 changed files with 3 additions and 3 deletions
|
@ -127,7 +127,7 @@ class Llama:
|
||||||
self.last_n_tokens_size = last_n_tokens_size
|
self.last_n_tokens_size = last_n_tokens_size
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||||
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
|
self.eval_logits: Deque[List[float]] = deque(
|
||||||
maxlen=n_ctx if logits_all else 1
|
maxlen=n_ctx if logits_all else 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -245,7 +245,7 @@ class Llama:
|
||||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||||
cols = int(n_vocab)
|
cols = int(n_vocab)
|
||||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||||
logits: List[List[llama_cpp.c_float]] = [
|
logits: List[List[float]] = [
|
||||||
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
|
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
|
||||||
]
|
]
|
||||||
self.eval_logits.extend(logits)
|
self.eval_logits.extend(logits)
|
||||||
|
@ -287,7 +287,7 @@ class Llama:
|
||||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||||
penalty=repeat_penalty,
|
penalty=repeat_penalty,
|
||||||
)
|
)
|
||||||
if float(temp) == 0.0:
|
if float(temp.value) == 0.0:
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
return llama_cpp.llama_sample_token_greedy(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||||
|
|
Loading…
Reference in a new issue