Update sampling api
This commit is contained in:
parent
7837c3fdc7
commit
350a1769e1
2 changed files with 113 additions and 28 deletions
|
@ -127,7 +127,9 @@ class Llama:
|
|||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
|
||||
maxlen=n_ctx if logits_all else 1
|
||||
)
|
||||
|
||||
self.cache: Optional[LlamaCache] = None
|
||||
|
||||
|
@ -236,17 +238,90 @@ class Llama:
|
|||
)
|
||||
if int(return_code) != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
# Save tokens
|
||||
self.eval_tokens.extend(batch)
|
||||
if self.params.logits_all:
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
rows = n_tokens
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits = [
|
||||
[logits_view[i * cols + j] for j in range(cols)]
|
||||
for i in range(rows)
|
||||
]
|
||||
self.eval_logits.extend(logits)
|
||||
# Save logits
|
||||
rows = n_tokens if self.params.logits_all else 1
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits: List[List[llama_cpp.c_float]] = [
|
||||
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
|
||||
]
|
||||
self.eval_logits.extend(logits)
|
||||
|
||||
def _sample_top_p_top_k(
|
||||
self,
|
||||
last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
|
||||
last_n_tokens_size: llama_cpp.c_int,
|
||||
top_k: llama_cpp.c_int,
|
||||
top_p: llama_cpp.c_float,
|
||||
temp: llama_cpp.c_float,
|
||||
repeat_penalty: llama_cpp.c_float,
|
||||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
|
||||
logits = self.eval_logits[-1]
|
||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
||||
*[
|
||||
llama_cpp.llama_token_data(
|
||||
id=llama_cpp.llama_token(i),
|
||||
logit=logits[i],
|
||||
p=llama_cpp.c_float(0.0),
|
||||
)
|
||||
for i in range(n_vocab)
|
||||
]
|
||||
)
|
||||
size = llama_cpp.c_size_t(n_vocab)
|
||||
sorted = False
|
||||
candidates = llama_cpp.llama_token_data_array(
|
||||
data=data,
|
||||
size=size,
|
||||
sorted=sorted,
|
||||
)
|
||||
llama_cpp.llama_sample_repetition_penalty(
|
||||
ctx=self.ctx,
|
||||
last_tokens_data=last_n_tokens_data,
|
||||
last_tokens_size=last_n_tokens_size,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
penalty=repeat_penalty,
|
||||
)
|
||||
if temp == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
)
|
||||
else:
|
||||
llama_cpp.llama_sample_top_k(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
k=top_k,
|
||||
)
|
||||
llama_cpp.llama_sample_tail_free(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
z=llama_cpp.c_float(1.0),
|
||||
)
|
||||
llama_cpp.llama_sample_typical(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
p=llama_cpp.c_float(1.0)
|
||||
)
|
||||
llama_cpp.llama_sample_top_p(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
p=top_p,
|
||||
)
|
||||
llama_cpp.llama_sample_temperature(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
temp=temp,
|
||||
)
|
||||
return llama_cpp.llama_sample_token(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
@ -270,8 +345,7 @@ class Llama:
|
|||
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
|
||||
0, self.last_n_tokens_size - len(self.eval_tokens)
|
||||
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
|
||||
return llama_cpp.llama_sample_top_p_top_k(
|
||||
ctx=self.ctx,
|
||||
return self._sample_top_p_top_k(
|
||||
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
||||
*last_n_tokens_data
|
||||
),
|
||||
|
@ -470,15 +544,15 @@ class Llama:
|
|||
all_text = self.detokenize(completion_tokens)
|
||||
|
||||
# Contains multi-byte UTF8
|
||||
for k,char in enumerate(all_text[-3:]):
|
||||
for k, char in enumerate(all_text[-3:]):
|
||||
k = 3 - k
|
||||
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||
# Bitwise AND check
|
||||
if (num > k and pattern & char == pattern):
|
||||
if num > k and pattern & char == pattern:
|
||||
multibyte_fix = num - k
|
||||
|
||||
# Stop incomplete bytes from passing
|
||||
if (multibyte_fix > 0):
|
||||
if multibyte_fix > 0:
|
||||
multibyte_fix -= 1
|
||||
continue
|
||||
|
||||
|
@ -531,7 +605,9 @@ class Llama:
|
|||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": text[returned_characters:].decode("utf-8", errors="ignore"),
|
||||
"text": text[returned_characters:].decode(
|
||||
"utf-8", errors="ignore"
|
||||
),
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": finish_reason,
|
||||
|
@ -558,7 +634,8 @@ class Llama:
|
|||
|
||||
all_tokens = prompt_tokens + completion_tokens
|
||||
all_token_strs = [
|
||||
self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
|
||||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
|
@ -577,7 +654,9 @@ class Llama:
|
|||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
|
||||
|
|
|
@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
|
|||
|
||||
|
||||
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
|
||||
def llama_sample_top_k(
|
||||
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
|
||||
|
||||
|
||||
|
@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_int,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_top_k.restype = None
|
||||
|
||||
|
||||
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
|
||||
def llama_sample_top_p(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
|
||||
|
||||
|
||||
|
@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_top_p.restype = None
|
||||
|
||||
|
||||
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
def llama_sample_tail_free(
|
||||
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
|
||||
ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
|
||||
|
||||
|
@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_tail_free.restype = None
|
||||
|
||||
|
||||
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
|
||||
def llama_sample_typical(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
|
||||
|
||||
|
||||
|
@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_typical.restype = None
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue