Implement sampling as in the llama.cpp main example

This commit is contained in:
Andrei Betlen 2023-05-08 21:21:25 -04:00
parent 93a9019bb1
commit d957422bf4

View file

@ -268,14 +268,13 @@ class Llama:
top_k: llama_cpp.c_int, top_k: llama_cpp.c_int,
top_p: llama_cpp.c_float, top_p: llama_cpp.c_float,
temp: llama_cpp.c_float, temp: llama_cpp.c_float,
mirostat_mode: llama_cpp.c_int, tfs_z: llama_cpp.c_float,
mirostat_tau: llama_cpp.c_float,
mirostat_eta: llama_cpp.c_float,
mirostat_mu: llama_cpp.c_float,
mirostat_m: llama_cpp.c_int,
repeat_penalty: llama_cpp.c_float, repeat_penalty: llama_cpp.c_float,
frequency_penalty: llama_cpp.c_float, frequency_penalty: llama_cpp.c_float,
presence_penalty: llama_cpp.c_float, presence_penalty: llama_cpp.c_float,
mirostat_mode: llama_cpp.c_int,
mirostat_tau: llama_cpp.c_float,
mirostat_eta: llama_cpp.c_float,
): ):
assert self.ctx is not None assert self.ctx is not None
assert len(self.eval_logits) > 0 assert len(self.eval_logits) > 0
@ -305,33 +304,6 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
penalty=repeat_penalty, penalty=repeat_penalty,
) )
if mirostat_mode.value == 1:
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
llama_cpp.llama_sample_token_mirostat(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
eta=mirostat_eta,
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
m=mirostat_m
)
elif mirostat_mode.value == 2:
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
temp=temp,
)
llama_cpp.llama_sample_token_mirostat_v2(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
eta=mirostat_eta,
mu=llama_cpp.ctypes.byref(mirostat_mu) # type: ignore
)
llama_cpp.llama_sample_frequency_and_presence_penalties( llama_cpp.llama_sample_frequency_and_presence_penalties(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
@ -340,11 +312,41 @@ class Llama:
alpha_frequency=frequency_penalty, alpha_frequency=frequency_penalty,
alpha_presence=presence_penalty, alpha_presence=presence_penalty,
) )
if float(temp.value) == 0.0: if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy( return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
) )
elif mirostat_mode.value == 1:
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
mirostat_m = llama_cpp.c_int(100)
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp,
)
return llama_cpp.llama_sample_token_mirostat(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
eta=mirostat_eta,
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
m=mirostat_m,
)
elif mirostat_mode.value == 2:
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
temp=temp,
)
return llama_cpp.llama_sample_token_mirostat_v2(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau,
eta=mirostat_eta,
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
)
else: else:
llama_cpp.llama_sample_top_k( llama_cpp.llama_sample_top_k(
ctx=self.ctx, ctx=self.ctx,
@ -355,7 +357,7 @@ class Llama:
llama_cpp.llama_sample_tail_free( llama_cpp.llama_sample_tail_free(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
z=llama_cpp.c_float(1.0), z=tfs_z,
min_keep=llama_cpp.c_size_t(1), min_keep=llama_cpp.c_size_t(1),
) )
llama_cpp.llama_sample_typical( llama_cpp.llama_sample_typical(
@ -382,17 +384,16 @@ class Llama:
def sample( def sample(
self, self,
top_k: int, top_k: int = 40,
top_p: float, top_p: float = 0.95,
temp: float, temp: float = 0.80,
mirostat_mode: int, repeat_penalty: float = 1.1,
mirostat_tau: float,
mirostat_eta: float,
mirostat_mu: float,
mirostat_m: int,
repeat_penalty: float,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
): ):
"""Sample a token from the model. """Sample a token from the model.
@ -417,14 +418,13 @@ class Llama:
top_k=llama_cpp.c_int(top_k), top_k=llama_cpp.c_int(top_k),
top_p=llama_cpp.c_float(top_p), top_p=llama_cpp.c_float(top_p),
temp=llama_cpp.c_float(temp), temp=llama_cpp.c_float(temp),
mirostat_mode=llama_cpp.c_int(mirostat_mode), tfs_z=llama_cpp.c_float(tfs_z),
mirostat_mu=llama_cpp.c_float(mirostat_mu),
mirostat_tau=llama_cpp.c_float(mirostat_tau),
mirostat_eta=llama_cpp.c_float(mirostat_eta),
mirostat_m=llama_cpp.c_int(mirostat_m),
repeat_penalty=llama_cpp.c_float(repeat_penalty), repeat_penalty=llama_cpp.c_float(repeat_penalty),
frequency_penalty=llama_cpp.c_float(frequency_penalty), frequency_penalty=llama_cpp.c_float(frequency_penalty),
presence_penalty=llama_cpp.c_float(presence_penalty), presence_penalty=llama_cpp.c_float(presence_penalty),
mirostat_mode=llama_cpp.c_int(mirostat_mode),
mirostat_tau=llama_cpp.c_float(mirostat_tau),
mirostat_eta=llama_cpp.c_float(mirostat_eta),
) )
def generate( def generate(
@ -433,15 +433,13 @@ class Llama:
top_k: int, top_k: int,
top_p: float, top_p: float,
temp: float, temp: float,
mirostat_mode: int,
mirostat_tau: float,
mirostat_eta: float,
mirostat_mu: float,
mirostat_m: int,
repeat_penalty: float, repeat_penalty: float,
reset: bool = True,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
reset: bool = True, mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
) -> Generator[ ) -> Generator[
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
]: ]:
@ -494,14 +492,12 @@ class Llama:
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
temp=temp, temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
mirostat_mode=mirostat_mode, mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau, mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
mirostat_mu=mirostat_mu,
mirostat_m=mirostat_m,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
) )
tokens_or_none = yield token tokens_or_none = yield token
tokens = [token] tokens = [token]
@ -571,11 +567,6 @@ class Llama:
suffix: Optional[str] = None, suffix: Optional[str] = None,
max_tokens: int = 16, max_tokens: int = 16,
temperature: float = 0.8, temperature: float = 0.8,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
mirostat_mu: float = 10,
mirostat_m: int = 100,
top_p: float = 0.95, top_p: float = 0.95,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
@ -585,6 +576,9 @@ class Llama:
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None assert self.ctx is not None
completion_id: str = f"cmpl-{str(uuid.uuid4())}" completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@ -643,8 +637,6 @@ class Llama:
mirostat_mode=mirostat_mode, mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau, mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
mirostat_mu=mirostat_mu,
mirostat_m=mirostat_m,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
@ -817,11 +809,6 @@ class Llama:
suffix: Optional[str] = None, suffix: Optional[str] = None,
max_tokens: int = 128, max_tokens: int = 128,
temperature: float = 0.8, temperature: float = 0.8,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
mirostat_mu: float = 10,
mirostat_m: int = 100,
top_p: float = 0.95, top_p: float = 0.95,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
@ -831,6 +818,9 @@ class Llama:
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
) -> Union[Completion, Iterator[CompletionChunk]]: ) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt. """Generate text from a prompt.
@ -859,11 +849,6 @@ class Llama:
suffix=suffix, suffix=suffix,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
mirostat_mu=mirostat_mu,
mirostat_m=mirostat_m,
top_p=top_p, top_p=top_p,
logprobs=logprobs, logprobs=logprobs,
echo=echo, echo=echo,
@ -873,6 +858,9 @@ class Llama:
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
top_k=top_k, top_k=top_k,
stream=stream, stream=stream,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
) )
if stream: if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks chunks: Iterator[CompletionChunk] = completion_or_chunks
@ -886,11 +874,6 @@ class Llama:
suffix: Optional[str] = None, suffix: Optional[str] = None,
max_tokens: int = 128, max_tokens: int = 128,
temperature: float = 0.8, temperature: float = 0.8,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
mirostat_mu: float = 10,
mirostat_m: int = 100,
top_p: float = 0.95, top_p: float = 0.95,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
@ -900,6 +883,9 @@ class Llama:
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
) -> Union[Completion, Iterator[CompletionChunk]]: ) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt. """Generate text from a prompt.
@ -928,11 +914,6 @@ class Llama:
suffix=suffix, suffix=suffix,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
mirostat_mu=mirostat_mu,
mirostat_m=mirostat_m,
top_p=top_p, top_p=top_p,
logprobs=logprobs, logprobs=logprobs,
echo=echo, echo=echo,
@ -942,6 +923,9 @@ class Llama:
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
top_k=top_k, top_k=top_k,
stream=stream, stream=stream,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
) )
def _convert_text_completion_to_chat( def _convert_text_completion_to_chat(
@ -1014,6 +998,9 @@ class Llama:
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages. """Generate a chat completion from a list of messages.
@ -1048,6 +1035,9 @@ class Llama:
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
) )
if stream: if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore