diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 58e407f..e6d150d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -44,6 +44,7 @@ class Llama: max_tokens: int = 16, temperature: float = 0.8, top_p: float = 0.95, + logprobs: Optional[int] = None, echo: bool = False, stop: List[str] = [], repeat_penalty: float = 1.1, @@ -105,6 +106,11 @@ class Llama: if suffix is not None: text = text + suffix + if logprobs is not None: + logprobs = llama_cpp.llama_get_logits( + self.ctx, + )[:logprobs] + return { "id": f"cmpl-{str(uuid.uuid4())}", # Likely to change "object": "text_completion", @@ -114,7 +120,7 @@ class Llama: { "text": text, "index": 0, - "logprobs": None, + "logprobs": logprobs, "finish_reason": finish_reason, } ],