fix: format

parent fe5d6ea648
commit 4d6b2f7b91

1 changed file with 13 additions and 8 deletions

@@ -197,13 +197,14 @@ class Llama:
         # kv_overrides is the original python dict
         self.kv_overrides = kv_overrides
         if kv_overrides is not None:
-
             # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
             kvo_array_len = len(kv_overrides) + 1  # for sentinel element
-            self._kv_overrides_array = (llama_cpp.llama_model_kv_override * kvo_array_len)()
+            self._kv_overrides_array = (
+                llama_cpp.llama_model_kv_override * kvo_array_len
+            )()
 
             for i, (k, v) in enumerate(kv_overrides.items()):
-                self._kv_overrides_array[i].key = k.encode('utf-8');
+                self._kv_overrides_array[i].key = k.encode("utf-8")
                 if isinstance(v, int):
                     self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
                     self._kv_overrides_array[i].value.int_value = v
@@ -216,7 +217,9 @@ class Llama:
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[-1].key = b'\0'  # ensure sentinel element is zeroed
+            self._kv_overrides_array[
+                -1
+            ].key = b"\0"  # ensure sentinel element is zeroed
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???
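The two hunks above only reflow the kv_overrides handling, but the pattern being formatted is worth spelling out: the Python dict is copied into a ctypes array that is one element longer than the dict, and the extra element is left zeroed so the C side can detect the end of the list. Below is a minimal, self-contained sketch of that sentinel pattern; the Override struct, its field names, and KV_OVERRIDE_INT are illustrative stand-ins, not the real llama_cpp.llama_model_kv_override binding.

import ctypes

# Hypothetical stand-in for llama_cpp.llama_model_kv_override; the field
# layout is illustrative only and does not match the real binding exactly.
KV_OVERRIDE_INT = 0


class Override(ctypes.Structure):
    _fields_ = [
        ("key", ctypes.c_char * 128),
        ("tag", ctypes.c_int),
        ("int_value", ctypes.c_int64),
    ]


def build_override_array(kv_overrides: dict):
    # One extra slot: ctypes zero-initializes it, so the empty key acts as
    # the sentinel that terminates the list on the C side.
    array_len = len(kv_overrides) + 1
    arr = (Override * array_len)()
    for i, (k, v) in enumerate(kv_overrides.items()):
        arr[i].key = k.encode("utf-8")
        if isinstance(v, int):
            arr[i].tag = KV_OVERRIDE_INT
            arr[i].int_value = v
        else:
            raise ValueError(f"Unknown value type for {k}: {v}")
    arr[-1].key = b"\0"  # explicit, though the slot is already zeroed
    return arr


overrides = build_override_array({"some.int.key": 42})
print(overrides[0].key, overrides[0].int_value)  # b'some.int.key' 42
print(overrides[-1].key)                         # b'' (sentinel)

In the real binding the value lives in a union and there are separate tags for int, float, and bool overrides, but the sentinel logic is the same.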
@@ -326,7 +329,9 @@ class Llama:
             (n_ctx, self._n_vocab), dtype=np.single
         )
 
-        self._mirostat_mu = ctypes.c_float(2.0 * 5.0)  # TODO: Move this to sampling context
+        self._mirostat_mu = ctypes.c_float(
+            2.0 * 5.0
+        )  # TODO: Move this to sampling context
 
         try:
             self.metadata = self._model.metadata()
@@ -334,7 +339,7 @@ class Llama:
             self.metadata = {}
             if self.verbose:
                 print(f"Failed to load metadata: {e}", file=sys.stderr)
 
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
@@ -534,7 +539,7 @@ class Llama:
                 candidates=self._candidates,
                 tau=mirostat_tau,
                 eta=mirostat_eta,
-                mu=ctypes.pointer(self._mirostat_mu)
+                mu=ctypes.pointer(self._mirostat_mu),
             )
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
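The remaining hunks reflow the mirostat state. _mirostat_mu is kept as a ctypes.c_float, initialized to 2.0 * 5.0 (which matches two times the default mirostat tau of 5.0), so that ctypes.pointer(self._mirostat_mu) can be handed to the sampler on each call and the value updated in place across tokens. A small sketch of that by-reference pattern, independent of llama.cpp; the update rule inside sample_step is a stand-in, not the actual mirostat algorithm.

import ctypes


def sample_step(mu_ptr, surprise):
    # Stand-in update rule: the real mirostat sampler adjusts mu from the
    # observed token surprise; the point here is only that the float owned
    # by the caller is mutated through the pointer.
    mu_ptr.contents.value -= 0.1 * (surprise - 5.0)


# Mirrors the diff: mu starts at 2 * tau with tau = 5.0.
mirostat_mu = ctypes.c_float(2.0 * 5.0)
for surprise in (6.0, 4.5, 5.2):
    sample_step(ctypes.pointer(mirostat_mu), surprise)
    print(round(mirostat_mu.value, 3))  # state persists across calls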