Update sampling api

This commit is contained in:
Andrei Betlen 2023-05-01 14:47:55 -04:00
parent 7837c3fdc7
commit 350a1769e1
2 changed files with 113 additions and 28 deletions

View file

@ -127,7 +127,9 @@ class Llama:
self.last_n_tokens_size = last_n_tokens_size self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch) self.n_batch = min(n_ctx, n_batch)
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx) self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx) self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
maxlen=n_ctx if logits_all else 1
)
self.cache: Optional[LlamaCache] = None self.cache: Optional[LlamaCache] = None
@ -236,18 +238,91 @@ class Llama:
) )
if int(return_code) != 0: if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}") raise RuntimeError(f"llama_eval returned {return_code}")
# Save tokens
self.eval_tokens.extend(batch) self.eval_tokens.extend(batch)
if self.params.logits_all: # Save logits
rows = n_tokens if self.params.logits_all else 1
n_vocab = llama_cpp.llama_n_vocab(self.ctx) n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab) cols = int(n_vocab)
rows = n_tokens
logits_view = llama_cpp.llama_get_logits(self.ctx) logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [ logits: List[List[llama_cpp.c_float]] = [
[logits_view[i * cols + j] for j in range(cols)] [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
for i in range(rows)
] ]
self.eval_logits.extend(logits) self.eval_logits.extend(logits)
def _sample_top_p_top_k(
self,
last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
last_n_tokens_size: llama_cpp.c_int,
top_k: llama_cpp.c_int,
top_p: llama_cpp.c_float,
temp: llama_cpp.c_float,
repeat_penalty: llama_cpp.c_float,
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
logits = self.eval_logits[-1]
data = (llama_cpp.llama_token_data * n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=logits[i],
p=llama_cpp.c_float(0.0),
)
for i in range(n_vocab)
]
)
size = llama_cpp.c_size_t(n_vocab)
sorted = False
candidates = llama_cpp.llama_token_data_array(
data=data,
size=size,
sorted=sorted,
)
llama_cpp.llama_sample_repetition_penalty(
ctx=self.ctx,
last_tokens_data=last_n_tokens_data,
last_tokens_size=last_n_tokens_size,
candidates=llama_cpp.ctypes.pointer(candidates),
penalty=repeat_penalty,
)
if temp == 0.0:
return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
)
else:
llama_cpp.llama_sample_top_k(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
k=top_k,
)
llama_cpp.llama_sample_tail_free(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
z=llama_cpp.c_float(1.0),
)
llama_cpp.llama_sample_typical(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=llama_cpp.c_float(1.0)
)
llama_cpp.llama_sample_top_p(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=top_p,
)
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
temp=temp,
)
return llama_cpp.llama_sample_token(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
)
def sample( def sample(
self, self,
top_k: int, top_k: int,
@ -270,8 +345,7 @@ class Llama:
last_n_tokens_data = [llama_cpp.llama_token(0)] * max( last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens) 0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :] ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
return llama_cpp.llama_sample_top_p_top_k( return self._sample_top_p_top_k(
ctx=self.ctx,
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*last_n_tokens_data *last_n_tokens_data
), ),
@ -470,15 +544,15 @@ class Llama:
all_text = self.detokenize(completion_tokens) all_text = self.detokenize(completion_tokens)
# Contains multi-byte UTF8 # Contains multi-byte UTF8
for k,char in enumerate(all_text[-3:]): for k, char in enumerate(all_text[-3:]):
k = 3 - k k = 3 - k
for num,pattern in [(2, 192), (3, 224), (4, 240)]: for num, pattern in [(2, 192), (3, 224), (4, 240)]:
# Bitwise AND check # Bitwise AND check
if (num > k and pattern & char == pattern): if num > k and pattern & char == pattern:
multibyte_fix = num - k multibyte_fix = num - k
# Stop incomplete bytes from passing # Stop incomplete bytes from passing
if (multibyte_fix > 0): if multibyte_fix > 0:
multibyte_fix -= 1 multibyte_fix -= 1
continue continue
@ -531,7 +605,9 @@ class Llama:
"model": self.model_path, "model": self.model_path,
"choices": [ "choices": [
{ {
"text": text[returned_characters:].decode("utf-8", errors="ignore"), "text": text[returned_characters:].decode(
"utf-8", errors="ignore"
),
"index": 0, "index": 0,
"logprobs": None, "logprobs": None,
"finish_reason": finish_reason, "finish_reason": finish_reason,
@ -558,7 +634,8 @@ class Llama:
all_tokens = prompt_tokens + completion_tokens all_tokens = prompt_tokens + completion_tokens
all_token_strs = [ all_token_strs = [
self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
] ]
all_logprobs = [ all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row] [Llama.logit_to_logprob(logit) for logit in row]
@ -577,7 +654,9 @@ class Llama:
) )
token_logprobs.append(sorted_logprobs[int(token)][0]) token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = { top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob self.detokenize([llama_cpp.llama_token(i)]).decode(
"utf-8", errors="ignore"
): logprob
for logprob, i in sorted_logprobs[:logprobs] for logprob, i in sorted_logprobs[:logprobs]
} }
top_logprob.update({token_str: sorted_logprobs[int(token)][0]}) top_logprob.update({token_str: sorted_logprobs[int(token)][0]})

View file

@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int): def llama_sample_top_k(
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep) return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
llama_context_p, llama_context_p,
llama_token_data_array_p, llama_token_data_array_p,
c_int, c_int,
c_int, c_size_t,
] ]
_lib.llama_sample_top_k.restype = None _lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int): def llama_sample_top_p(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep) return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
llama_context_p, llama_context_p,
llama_token_data_array_p, llama_token_data_array_p,
c_float, c_float,
c_int, c_size_t,
] ]
_lib.llama_sample_top_p.restype = None _lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free( def llama_sample_tail_free(
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
): ):
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep) return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
llama_context_p, llama_context_p,
llama_token_data_array_p, llama_token_data_array_p,
c_float, c_float,
c_int, c_size_t,
] ]
_lib.llama_sample_tail_free.restype = None _lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int): def llama_sample_typical(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_typical(ctx, candidates, p, min_keep) return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
llama_context_p, llama_context_p,
llama_token_data_array_p, llama_token_data_array_p,
c_float, c_float,
c_int, c_size_t,
] ]
_lib.llama_sample_typical.restype = None _lib.llama_sample_typical.restype = None