From b3805bb9ccc2a33d68b568cd00e10f89a0f9506b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 12 Apr 2023 14:05:11 -0400 Subject: [PATCH] Implement logprobs parameter for text completion. Closes #2 --- llama_cpp/llama.py | 125 ++++++++++++++++++++++++++++++----- llama_cpp/server/__main__.py | 2 + 2 files changed, 111 insertions(+), 16 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2d76ec4..3e13776 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -2,6 +2,7 @@ import os import sys import uuid import time +import math import multiprocessing from typing import List, Optional, Union, Generator, Sequence, Iterator from collections import deque @@ -76,6 +77,9 @@ class Llama: ) self.tokens_consumed = 0 self.n_batch = min(n_ctx, n_batch) + self.n_tokens = 0 + self.n_past = 0 + self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list. self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) @@ -136,6 +140,9 @@ class Llama: [llama_cpp.llama_token(0)] * self.last_n_tokens_size ) self.tokens_consumed = 0 + self.n_tokens = 0 + self.n_past = 0 + self.all_logits = [] def eval(self, tokens: Sequence[llama_cpp.llama_token]): """Evaluate a list of tokens. @@ -147,18 +154,31 @@ class Llama: n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] - n_past = min(n_ctx - len(batch), self.tokens_consumed) + self.n_past = min(n_ctx - len(batch), self.tokens_consumed) + self.n_tokens = len(batch) return_code = llama_cpp.llama_eval( ctx=self.ctx, tokens=(llama_cpp.llama_token * len(batch))(*batch), - n_tokens=llama_cpp.c_int(len(batch)), - n_past=llama_cpp.c_int(n_past), + n_tokens=llama_cpp.c_int(self.n_tokens), + n_past=llama_cpp.c_int(self.n_past), n_threads=llama_cpp.c_int(self.n_threads), ) if int(return_code) != 0: raise RuntimeError(f"llama_eval returned {return_code}") self.last_n_tokens_data.extend(batch) self.tokens_consumed += len(batch) + if self.params.logits_all: + self.all_logits.extend(self._logits()) + + def _logits(self) -> List[List[float]]: + """Return the logits from the last call to llama_eval.""" + assert self.ctx is not None + n_vocab = llama_cpp.llama_n_vocab(self.ctx) + cols = int(n_vocab) + rows = self.n_tokens if self.params.logits_all else 1 + logits_view = llama_cpp.llama_get_logits(self.ctx) + logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)] + return logits def sample( self, @@ -327,14 +347,55 @@ class Llama: else: stop_sequences = [] - finish_reason = None - for token in self.generate( - prompt_tokens, - top_k=top_k, - top_p=top_p, - temp=temperature, - repeat_penalty=repeat_penalty, - ): + text_offset = 0 + text_offsets: List[int] = [] + token_logprobs: List[float] = [] + tokens: List[str] = [] + top_logprobs: List[Dict[str, float]] = [] + + self.reset() + self.eval(prompt_tokens) + + if logprobs is not None and self.params.logits_all is False: + raise ValueError( + "logprobs is not supported for models created with logits_all=False" + ) + + if logprobs is not None: + token_strs = [ + self.detokenize([token]).decode("utf-8") for token in prompt_tokens + ] + logprobs_all = [ + [Llama.logit_to_logprob(logit) for logit in row] + for row in self.all_logits + ] + for token, token_str, logprobs_token in zip( + prompt_tokens, token_strs, logprobs_all + ): + text_offsets.append(text_offset) + text_offset += len(token_str) + tokens.append(token_str) + sorted_logprobs = list( + sorted( + zip(logprobs_token, range(len(logprobs_token))), reverse=True + ) + ) + token_logprobs.append(sorted_logprobs[int(token)][0]) + top_logprob = { + self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob + for logprob, i in sorted_logprobs[:logprobs] + } + top_logprob.update({token_str: sorted_logprobs[int(token)][0]}) + top_logprobs.append(top_logprob) + + finish_reason = "length" + while True: + token = self.sample( + top_k=top_k, + top_p=top_p, + temp=temperature, + repeat_penalty=repeat_penalty, + ) if token == llama_cpp.llama_token_eos(): text = self.detokenize(completion_tokens) finish_reason = "stop" @@ -377,13 +438,35 @@ class Llama: } ], } + + if logprobs is not None: + # TODO: Confirm wether this should happen before or after + # next eval. + token_str = self.detokenize([token]).decode("utf-8") + text_offsets.append(text_offset) + text_offset += len(token_str) + tokens.append(token_str) + logprobs_token = [ + Llama.logit_to_logprob(logit) for logit in self.all_logits[-1] + ] + sorted_logprobs = list( + sorted( + zip(logprobs_token, range(len(logprobs_token))), reverse=True + ) + ) + token_logprobs.append(sorted_logprobs[int(token)][0]) + top_logprob = { + self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob + for logprob, i in sorted_logprobs[:logprobs] + } + top_logprob.update({token_str: logprobs_token[int(token)]}) + top_logprobs.append(top_logprob) + if len(completion_tokens) >= max_tokens: text = self.detokenize(completion_tokens) finish_reason = "length" break - - if finish_reason is None: - finish_reason = "length" + self.eval([token]) if stream: yield { @@ -410,8 +493,14 @@ class Llama: if suffix is not None: text = text + suffix + logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: - raise NotImplementedError("logprobs not implemented") + logprobs_or_none = { + "tokens": tokens, + "text_offset": text_offsets, + "token_logprobs": token_logprobs, + "top_logprobs": top_logprobs, + } if self.verbose: llama_cpp.llama_print_timings(self.ctx) @@ -425,7 +514,7 @@ class Llama: { "text": text, "index": 0, - "logprobs": None, + "logprobs": logprobs_or_none, "finish_reason": finish_reason, } ], @@ -704,3 +793,7 @@ class Llama: def token_bos() -> llama_cpp.llama_token: """Return the beginning-of-sequence token.""" return llama_cpp.llama_token_bos() + + @staticmethod + def logit_to_logprob(x: float) -> float: + return math.log(1.0 + math.exp(x)) diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index 80cbe01..49a00b2 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -33,6 +33,7 @@ class Settings(BaseSettings): use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out... embedding: bool = True last_n_tokens_size: int = 64 + logits_all: bool = False app = FastAPI( @@ -52,6 +53,7 @@ llama = llama_cpp.Llama( f16_kv=settings.f16_kv, use_mlock=settings.use_mlock, embedding=settings.embedding, + logits_all=settings.logits_all, n_threads=settings.n_threads, n_batch=settings.n_batch, n_ctx=settings.n_ctx,