Update sampling api

parent 7837c3fdc7
commit 350a1769e1

2 changed files with 113 additions and 28 deletions
llama_cpp/llama.py

@@ -127,7 +127,9 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )
 
         self.cache: Optional[LlamaCache] = None
 
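Note: with this change, eval_logits retains a row of logits for every evaluated token only when logits_all is set; otherwise it keeps just the most recent row, which is all that sampling needs. A minimal sketch of the retention behavior (values illustrative):

from collections import deque
from typing import Deque, List

logits_all = False
n_ctx = 512

# Keep logits for every evaluated token only when logits_all is set;
# otherwise a maxlen of 1 retains just the last row.
eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)

eval_logits.append([0.1, 0.2])  # logits after step 1
eval_logits.append([0.3, 0.4])  # step 2 evicts step 1 when maxlen == 1
assert list(eval_logits) == [[0.3, 0.4]]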
@@ -236,18 +238,91 @@ class Llama:
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            # Save tokens
             self.eval_tokens.extend(batch)
-            if self.params.logits_all:
-                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-                cols = int(n_vocab)
-                rows = n_tokens
-                logits_view = llama_cpp.llama_get_logits(self.ctx)
-                logits = [
-                    [logits_view[i * cols + j] for j in range(cols)]
-                    for i in range(rows)
-                ]
-                self.eval_logits.extend(logits)
+            # Save logits
+            rows = n_tokens if self.params.logits_all else 1
+            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            cols = int(n_vocab)
+            logits_view = llama_cpp.llama_get_logits(self.ctx)
+            logits: List[List[llama_cpp.c_float]] = [
+                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+            ]
+            self.eval_logits.extend(logits)
 
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+
     def sample(
         self,
         top_k: int,
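Note: the new _sample_top_p_top_k helper replaces the removed one-shot llama_cpp.llama_sample_top_p_top_k binding with the composable sampler calls: build a llama_token_data_array from the last logits, apply the repetition penalty, then either pick greedily when temp == 0.0 or run top-k, tail-free (z=1.0, effectively disabled), locally typical (p=1.0, effectively disabled), top-p, and temperature before drawing a token. Below is a pure-Python sketch of that chain for illustration only, not the library code; tail-free and typical sampling are omitted since they are disabled here, and temperature is folded into the softmax before the nucleus cut for simplicity:

import math
import random

def sample_top_p_top_k(logits, last_tokens, top_k=40, top_p=0.95,
                       temp=0.8, repeat_penalty=1.1):
    logits = list(logits)
    # Repetition penalty: make recently seen tokens less likely.
    for t in set(last_tokens):
        if logits[t] > 0:
            logits[t] /= repeat_penalty
        else:
            logits[t] *= repeat_penalty
    if temp == 0.0:
        # Greedy: highest logit wins.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Top-k: keep only the k highest-logit candidates.
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Temperature-scaled softmax (shifted by the max for stability).
    m = max(logits[i] for i in candidates)
    weights = [math.exp((logits[i] - m) / temp) for i in candidates]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Top-p (nucleus): keep the smallest prefix whose mass reaches top_p.
    kept, cum = [], 0.0
    for tok, p in zip(candidates, probs):
        kept.append((tok, p))
        cum += p
        if cum >= top_p:
            break
    total = sum(p for _, p in kept)
    return random.choices([t for t, _ in kept], [p / total for _, p in kept])[0]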
@@ -270,8 +345,7 @@ class Llama:
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
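Note: callers of the public sample() method are unaffected; it now routes through the internal helper instead of the old one-shot C function. A usage sketch (the model path is illustrative):

from llama_cpp import Llama

llama = Llama(model_path="./models/ggml-model.bin")  # illustrative path
llama.eval(llama.tokenize(b"Q: Name the planets in the solar system. A:"))
token = llama.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
print(llama.detokenize([token]).decode("utf-8", errors="ignore"))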
@@ -470,15 +544,15 @@ class Llama:
             all_text = self.detokenize(completion_tokens)
 
             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k
 
             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue
 
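Note: the masks 192 (0b11000000), 224 (0b11100000), and 240 (0b11110000) match UTF-8 lead bytes of 2-, 3-, and 4-byte sequences. The loop checks whether one of the last three output bytes starts a sequence that is not yet complete, so the stream never emits a torn character. A standalone sketch of the same check:

def incomplete_utf8_tail(buf: bytes) -> int:
    """Count continuation bytes still missing from an unfinished UTF-8 char."""
    missing = 0
    tail = buf[-3:]
    for idx, byte in enumerate(tail):
        back = len(tail) - idx  # distance of this byte from the end of buf
        for length, lead in [(2, 0b11000000), (3, 0b11100000), (4, 0b11110000)]:
            # A lead byte `back` positions from the end needs `length` bytes total.
            if length > back and byte & lead == lead:
                missing = length - back
    return missing

assert incomplete_utf8_tail("€".encode("utf-8")[:2]) == 1  # 0xE2 0x82: one byte short
assert incomplete_utf8_tail(b"ok") == 0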
@@ -531,7 +605,9 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                        "text": text[returned_characters:].decode(
+                            "utf-8", errors="ignore"
+                        ),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ class Llama:
 
         all_tokens = prompt_tokens + completion_tokens
         all_token_strs = [
-            self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+            self.detokenize([token]).decode("utf-8", errors="ignore")
+            for token in all_tokens
         ]
         all_logprobs = [
             [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ class Llama:
             )
             token_logprobs.append(sorted_logprobs[int(token)][0])
             top_logprob = {
-                self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                self.detokenize([llama_cpp.llama_token(i)]).decode(
+                    "utf-8", errors="ignore"
+                ): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
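Note: sorted_logprobs holds (logprob, token_id) pairs in descending order, so the comprehension maps each top token's decoded string to its logprob. A sketch of producing such pairs from one row of logits, assuming the conversion is a log-softmax over the row (the helper name is illustrative; the class's own logit_to_logprob may differ):

import math

def top_logprobs(row, n):
    # Log-softmax via the log-sum-exp trick for numerical stability.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    pairs = sorted(((x - log_z, i) for i, x in enumerate(row)), reverse=True)
    return pairs[:n]  # (logprob, token_id) pairs, like sorted_logprobs

print(top_logprobs([1.0, 3.0, 2.0], 2))  # token 1 first, then token 2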
llama_cpp/llama_cpp.py

@@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
 
 
 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
+def llama_sample_top_k(
+    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
 
 
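Note: min_keep changes from c_int to c_size_t to match the upstream llama.cpp signatures, and gains a default of 1 so callers that omit it (like _sample_top_p_top_k above) still keep at least one candidate. The same change is applied to llama_sample_top_p, llama_sample_tail_free, and llama_sample_typical below. A call sketch; ctx and candidates are assumed to come from a loaded model, as in _sample_top_p_top_k:

import llama_cpp

# min_keep can now be omitted thanks to the default...
llama_cpp.llama_sample_top_k(
    ctx=ctx,
    candidates=llama_cpp.ctypes.pointer(candidates),
    k=llama_cpp.c_int(40),
)
# ...or passed explicitly to guarantee more survivors.
llama_cpp.llama_sample_top_k(
    ctx=ctx,
    candidates=llama_cpp.ctypes.pointer(candidates),
    k=llama_cpp.c_int(40),
    min_keep=llama_cpp.c_size_t(4),
)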
@@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_int,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_k.restype = None
 
 
 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_top_p(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
 
 
@@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_p.restype = None
 
 
 # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
 def llama_sample_tail_free(
-    ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
+    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
 
@@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_tail_free.restype = None
 
 
 # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_typical(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
 
 
@@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_typical.restype = None
 