Update sampling api

This commit is contained in:
Andrei Betlen 2023-05-01 14:47:55 -04:00
parent 7837c3fdc7
commit 350a1769e1
2 changed files with 113 additions and 28 deletions

View file

@ -127,7 +127,9 @@ class Llama:
self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch)
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
maxlen=n_ctx if logits_all else 1
)
self.cache: Optional[LlamaCache] = None
@ -236,17 +238,90 @@ class Llama:
)
if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}")
# Save tokens
self.eval_tokens.extend(batch)
if self.params.logits_all:
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab)
rows = n_tokens
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [
[logits_view[i * cols + j] for j in range(cols)]
for i in range(rows)
]
self.eval_logits.extend(logits)
# Save logits
rows = n_tokens if self.params.logits_all else 1
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab)
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits: List[List[llama_cpp.c_float]] = [
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
]
self.eval_logits.extend(logits)
def _sample_top_p_top_k(
self,
last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
last_n_tokens_size: llama_cpp.c_int,
top_k: llama_cpp.c_int,
top_p: llama_cpp.c_float,
temp: llama_cpp.c_float,
repeat_penalty: llama_cpp.c_float,
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
logits = self.eval_logits[-1]
data = (llama_cpp.llama_token_data * n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=logits[i],
p=llama_cpp.c_float(0.0),
)
for i in range(n_vocab)
]
)
size = llama_cpp.c_size_t(n_vocab)
sorted = False
candidates = llama_cpp.llama_token_data_array(
data=data,
size=size,
sorted=sorted,
)
llama_cpp.llama_sample_repetition_penalty(
ctx=self.ctx,
last_tokens_data=last_n_tokens_data,
last_tokens_size=last_n_tokens_size,
candidates=llama_cpp.ctypes.pointer(candidates),
penalty=repeat_penalty,
)
if temp == 0.0:
return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
)
else:
llama_cpp.llama_sample_top_k(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
k=top_k,
)
llama_cpp.llama_sample_tail_free(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
z=llama_cpp.c_float(1.0),
)
llama_cpp.llama_sample_typical(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=llama_cpp.c_float(1.0)
)
llama_cpp.llama_sample_top_p(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=top_p,
)
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
temp=temp,
)
return llama_cpp.llama_sample_token(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
)
def sample(
self,
@ -270,8 +345,7 @@ class Llama:
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
return llama_cpp.llama_sample_top_p_top_k(
ctx=self.ctx,
return self._sample_top_p_top_k(
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*last_n_tokens_data
),
@ -470,15 +544,15 @@ class Llama:
all_text = self.detokenize(completion_tokens)
# Contains multi-byte UTF8
for k,char in enumerate(all_text[-3:]):
for k, char in enumerate(all_text[-3:]):
k = 3 - k
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
# Bitwise AND check
if (num > k and pattern & char == pattern):
if num > k and pattern & char == pattern:
multibyte_fix = num - k
# Stop incomplete bytes from passing
if (multibyte_fix > 0):
if multibyte_fix > 0:
multibyte_fix -= 1
continue
@ -531,7 +605,9 @@ class Llama:
"model": self.model_path,
"choices": [
{
"text": text[returned_characters:].decode("utf-8", errors="ignore"),
"text": text[returned_characters:].decode(
"utf-8", errors="ignore"
),
"index": 0,
"logprobs": None,
"finish_reason": finish_reason,
@ -558,7 +634,8 @@ class Llama:
all_tokens = prompt_tokens + completion_tokens
all_token_strs = [
self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row]
@ -577,7 +654,9 @@ class Llama:
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
self.detokenize([llama_cpp.llama_token(i)]).decode(
"utf-8", errors="ignore"
): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})

View file

@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
def llama_sample_top_k(
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_int,
c_int,
c_size_t,
]
_lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
def llama_sample_top_p(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_float,
c_int,
c_size_t,
]
_lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free(
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_float,
c_int,
c_size_t,
]
_lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
def llama_sample_typical(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_float,
c_int,
c_size_t,
]
_lib.llama_sample_typical.restype = None