Added grammar based sampling
This commit is contained in:
parent
ac188a21f3
commit
418aa83b01
2 changed files with 537 additions and 543 deletions
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
@ -23,6 +24,7 @@ import ctypes
|
||||||
|
|
||||||
from . import llama_cpp
|
from . import llama_cpp
|
||||||
from .llama_types import *
|
from .llama_types import *
|
||||||
|
from .llama_grammar import LlamaGrammar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
@ -223,6 +225,7 @@ class Llama:
|
||||||
tensor_split: Optional[List[float]] = None,
|
tensor_split: Optional[List[float]] = None,
|
||||||
rope_freq_base: float = 10000.0,
|
rope_freq_base: float = 10000.0,
|
||||||
rope_freq_scale: float = 1.0,
|
rope_freq_scale: float = 1.0,
|
||||||
|
grammar: Optional[Union[str, Path]] = None,
|
||||||
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
|
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
|
||||||
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
|
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
|
@ -248,6 +251,7 @@ class Llama:
|
||||||
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
|
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
|
||||||
rope_freq_base: Base frequency for rope sampling.
|
rope_freq_base: Base frequency for rope sampling.
|
||||||
rope_freq_scale: Scale factor for rope sampling.
|
rope_freq_scale: Scale factor for rope sampling.
|
||||||
|
grammar: Path to a BNF grammar file to use for grammar based sampling.
|
||||||
verbose: Print verbose output to stderr.
|
verbose: Print verbose output to stderr.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -358,6 +362,12 @@ class Llama:
|
||||||
self.scores: npt.NDArray[np.single] = np.ndarray(
|
self.scores: npt.NDArray[np.single] = np.ndarray(
|
||||||
(n_ctx, self._n_vocab), dtype=np.single
|
(n_ctx, self._n_vocab), dtype=np.single
|
||||||
)
|
)
|
||||||
|
if grammar is not None:
|
||||||
|
self.grammar = LlamaGrammar.from_file(
|
||||||
|
grammar
|
||||||
|
) # type: Optional[LlamaGrammar]
|
||||||
|
else:
|
||||||
|
self.grammar = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _input_ids(self) -> npt.NDArray[np.intc]:
|
def _input_ids(self) -> npt.NDArray[np.intc]:
|
||||||
|
@ -542,8 +552,16 @@ class Llama:
|
||||||
)
|
)
|
||||||
if not penalize_nl:
|
if not penalize_nl:
|
||||||
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
|
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
|
||||||
|
|
||||||
|
if self.grammar is not None:
|
||||||
|
llama_cpp.llama_sample_grammar(
|
||||||
|
ctx=self.ctx,
|
||||||
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
|
grammar=self.grammar.grammar,
|
||||||
|
)
|
||||||
|
|
||||||
if temp.value == 0.0:
|
if temp.value == 0.0:
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
id = llama_cpp.llama_sample_token_greedy(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
)
|
)
|
||||||
|
@ -555,7 +573,7 @@ class Llama:
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
temp=temp,
|
temp=temp,
|
||||||
)
|
)
|
||||||
return llama_cpp.llama_sample_token_mirostat(
|
id = llama_cpp.llama_sample_token_mirostat(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
tau=mirostat_tau,
|
tau=mirostat_tau,
|
||||||
|
@ -570,7 +588,7 @@ class Llama:
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
temp=temp,
|
temp=temp,
|
||||||
)
|
)
|
||||||
return llama_cpp.llama_sample_token_mirostat_v2(
|
id = llama_cpp.llama_sample_token_mirostat_v2(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
tau=mirostat_tau,
|
tau=mirostat_tau,
|
||||||
|
@ -607,10 +625,17 @@ class Llama:
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
temp=temp,
|
temp=temp,
|
||||||
)
|
)
|
||||||
return llama_cpp.llama_sample_token(
|
id = llama_cpp.llama_sample_token(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
)
|
)
|
||||||
|
if self.grammar is not None:
|
||||||
|
llama_cpp.llama_grammar_accept_token(
|
||||||
|
ctx=self.ctx,
|
||||||
|
grammar=self.grammar.grammar,
|
||||||
|
token=llama_cpp.ctypes.c_int(id),
|
||||||
|
)
|
||||||
|
return id
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
|
@ -1509,6 +1534,9 @@ class Llama:
|
||||||
if self.ctx is not None:
|
if self.ctx is not None:
|
||||||
llama_cpp.llama_free(self.ctx)
|
llama_cpp.llama_free(self.ctx)
|
||||||
self.ctx = None
|
self.ctx = None
|
||||||
|
if self.grammar is not None:
|
||||||
|
llama_cpp.llama_grammar_free(self.grammar.grammar)
|
||||||
|
self.grammar = None
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return dict(
|
return dict(
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue