diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4869a9d..8d726d3 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -557,7 +557,7 @@ class Llama: logits[:] = ( logits_processor(self._input_ids, logits) if idx is None - else logits_processor(self._input_ids[:idx], logits) + else logits_processor(self._input_ids[:idx + 1], logits) ) sampling_params = _LlamaSamplingParams(