diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a643f51..fc91ea4 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -51,7 +51,7 @@ class LlamaState: def __init__( self, eval_tokens: Deque[llama_cpp.llama_token], - eval_logits: Deque[List[llama_cpp.c_float]], + eval_logits: Deque[List[float]], llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] llama_state_size: llama_cpp.c_size_t, ):