Added grammar-based sampling

This commit is contained in:
c0sogi 2023-08-07 02:21:37 +09:00
parent ac188a21f3
commit 418aa83b01
2 changed files with 537 additions and 543 deletions

View file

@ -1,4 +1,5 @@
import os import os
from pathlib import Path
import sys import sys
import uuid import uuid
import time import time
@ -23,6 +24,7 @@ import ctypes
from . import llama_cpp from . import llama_cpp
from .llama_types import * from .llama_types import *
from .llama_grammar import LlamaGrammar
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -223,6 +225,7 @@ class Llama:
tensor_split: Optional[List[float]] = None, tensor_split: Optional[List[float]] = None,
rope_freq_base: float = 10000.0, rope_freq_base: float = 10000.0,
rope_freq_scale: float = 1.0, rope_freq_scale: float = 1.0,
grammar: Optional[Union[str, Path]] = None,
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
rms_norm_eps: Optional[float] = None, # (TEMPORARY) rms_norm_eps: Optional[float] = None, # (TEMPORARY)
verbose: bool = True, verbose: bool = True,
@ -248,6 +251,7 @@ class Llama:
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split. tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
rope_freq_base: Base frequency for rope sampling. rope_freq_base: Base frequency for rope sampling.
rope_freq_scale: Scale factor for rope sampling. rope_freq_scale: Scale factor for rope sampling.
grammar: Path to a BNF grammar file to use for grammar-based sampling.
verbose: Print verbose output to stderr. verbose: Print verbose output to stderr.
Raises: Raises:
@ -358,6 +362,12 @@ class Llama:
self.scores: npt.NDArray[np.single] = np.ndarray( self.scores: npt.NDArray[np.single] = np.ndarray(
(n_ctx, self._n_vocab), dtype=np.single (n_ctx, self._n_vocab), dtype=np.single
) )
if grammar is not None:
self.grammar = LlamaGrammar.from_file(
grammar
) # type: Optional[LlamaGrammar]
else:
self.grammar = None
@property @property
def _input_ids(self) -> npt.NDArray[np.intc]: def _input_ids(self) -> npt.NDArray[np.intc]:
@ -542,8 +552,16 @@ class Llama:
) )
if not penalize_nl: if not penalize_nl:
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit) candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
if self.grammar is not None:
llama_cpp.llama_sample_grammar(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
grammar=self.grammar.grammar,
)
if temp.value == 0.0: if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy( id = llama_cpp.llama_sample_token_greedy(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
) )
@ -555,7 +573,7 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp, temp=temp,
) )
return llama_cpp.llama_sample_token_mirostat( id = llama_cpp.llama_sample_token_mirostat(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau, tau=mirostat_tau,
@ -570,7 +588,7 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp, temp=temp,
) )
return llama_cpp.llama_sample_token_mirostat_v2( id = llama_cpp.llama_sample_token_mirostat_v2(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
tau=mirostat_tau, tau=mirostat_tau,
@ -607,10 +625,17 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
temp=temp, temp=temp,
) )
return llama_cpp.llama_sample_token( id = llama_cpp.llama_sample_token(
ctx=self.ctx, ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
) )
if self.grammar is not None:
llama_cpp.llama_grammar_accept_token(
ctx=self.ctx,
grammar=self.grammar.grammar,
token=llama_cpp.ctypes.c_int(id),
)
return id
def sample( def sample(
self, self,
@ -1509,6 +1534,9 @@ class Llama:
if self.ctx is not None: if self.ctx is not None:
llama_cpp.llama_free(self.ctx) llama_cpp.llama_free(self.ctx)
self.ctx = None self.ctx = None
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar.grammar)
self.grammar = None
def __getstate__(self): def __getstate__(self):
return dict( return dict(

File diff suppressed because it is too large Load diff