Fix llama_cpp and Llama type signatures. Closes #221
parent fb57b9470b
commit 01a010be52
3 changed files with 58 additions and 64 deletions
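For callers of the high-level API, the effect of this change is that token ids are plain Python ints throughout `Llama` (`tokenize`, `detokenize`, `eval`, `generate`, the caches), with `llama_cpp.llama_token` reserved for the C boundary. A minimal usage sketch of the resulting signatures (the model path is a placeholder, not part of this commit):

from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder path

tokens = llm.tokenize(b"Hello, world!")  # -> List[int], no llama_token wrappers
text = llm.detokenize(tokens)            # accepts plain ints, returns bytes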
@@ -15,9 +15,7 @@ class LlamaCache:
     """Cache for a llama.cpp model."""

     def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[
-            Tuple[llama_cpp.llama_token, ...], "LlamaState"
-        ] = OrderedDict()
+        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
         self.capacity_bytes = capacity_bytes

     @property
@@ -26,8 +24,8 @@ class LlamaCache:

     def _find_longest_prefix_key(
         self,
-        key: Tuple[llama_cpp.llama_token, ...],
-    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
         keys = (
@@ -39,7 +37,7 @@ class LlamaCache:
                 min_key = k
         return min_key

-    def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
@@ -48,10 +46,10 @@ class LlamaCache:
         self.cache_state.move_to_end(_key)
         return value

-    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+    def __contains__(self, key: Sequence[int]) -> bool:
         return self._find_longest_prefix_key(tuple(key)) is not None

-    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
         key = tuple(key)
         if key in self.cache_state:
             del self.cache_state[key]
@@ -63,7 +61,7 @@ class LlamaCache:
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
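For context on the cache signatures above: keys are tuples of int token ids, and lookups pick the cached key sharing the longest prefix with the query. A simplified sketch of that logic under the new types (not the full class):

from collections import OrderedDict
from typing import Optional, Sequence, Tuple

cache_state: "OrderedDict[Tuple[int, ...], object]" = OrderedDict()

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    # count matching leading token ids, stop at the first mismatch
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def find_longest_prefix_key(key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
    best_len, best_key = 0, None
    for k in cache_state.keys():
        prefix_len = longest_token_prefix(k, key)
        if prefix_len > best_len:
            best_len, best_key = prefix_len, k
    return best_key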
@@ -141,7 +139,7 @@ class Llama:

         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)

         self.cache: Optional[LlamaCache] = None
@@ -176,9 +174,7 @@ class Llama:
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

-    def tokenize(
-        self, text: bytes, add_bos: bool = True
-    ) -> List[llama_cpp.llama_token]:
+    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.

         Args:
@@ -197,7 +193,7 @@ class Llama:
             self.ctx,
             text,
             tokens,
-            n_ctx,
+            llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
@@ -216,7 +212,7 @@ class Llama:
         )
         return list(tokens[:n_tokens])

-    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> bytes:
         """Detokenize a list of tokens.

         Args:
@@ -228,7 +224,9 @@ class Llama:
         assert self.ctx is not None
         output = b""
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(self.ctx, token)
+            output += llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token)
+            )
         return output

     def set_cache(self, cache: Optional[LlamaCache]):
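Background on why the wrapper type is pushed down to the C call (as `detokenize` now does above): `llama_cpp.llama_token` is a ctypes integer type, and ctypes instances do not compare by value, so keeping them out of the public API avoids subtle bugs in caches and stop-token checks. A small illustrative sketch, assuming `llama_token` is a `c_int` alias as in the low-level bindings:

import ctypes

llama_token = ctypes.c_int  # assumed alias, mirroring the low-level bindings

a = llama_token(42)
b = llama_token(42)
print(a == b)              # False: ctypes instances compare by identity, not value
print(a.value == b.value)  # True: the wrapped ints do match
print(42 == 42)            # True: plain ints behave as expected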
@@ -244,7 +242,7 @@ class Llama:
         self.eval_tokens.clear()
         self.eval_logits.clear()

-    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+    def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.

         Args:
@@ -458,7 +456,7 @@ class Llama:

     def generate(
         self,
-        tokens: Sequence[llama_cpp.llama_token],
+        tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
         temp: float = 0.80,
@@ -470,9 +468,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-    ) -> Generator[
-        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
-    ]:
+    ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

         Examples:
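With the `Generator[int, Optional[Sequence[int]], None]` annotation above, the values yielded by `generate` are plain ints and can be compared directly with the int-returning token helpers. A minimal consumption sketch (placeholder model path; only parameters visible in this diff are used):

from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder path
prompt_tokens = llm.tokenize(b"Hello")

for token in llm.generate(prompt_tokens, top_k=40, top_p=0.95, temp=0.80):
    if token == Llama.token_eos():  # both sides are plain ints
        break
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")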
@@ -617,14 +613,14 @@ class Llama:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[llama_cpp.llama_token] = []
+        completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
-        )
+        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         model_name: str = model if model is not None else self.model_path

         if self.verbose:
@@ -724,7 +720,9 @@ class Llama:
                 for token in remaining_tokens:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token
-                    if token_end_position >= (remaining_length - first_stop_position - 1):
+                    if token_end_position >= (
+                        remaining_length - first_stop_position - 1
+                    ):
                         break
                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
@@ -744,7 +742,7 @@ class Llama:
                            )
                        )
                        top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            self.detokenize([i]).decode(
                                "utf-8", errors="ignore"
                            ): logprob
                            for logprob, i in sorted_logprobs[:logprobs]
@@ -822,9 +820,7 @@ class Llama:
                    )
                )
                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                    for logprob, i in sorted_logprobs[:logprobs]
                }
                top_logprob.update({token_str: current_logprobs[int(token)]})
@@ -924,9 +920,7 @@ class Llama:
                )
                token_logprobs.append(sorted_logprobs[int(token)][0])
                top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                    for logprob, i in sorted_logprobs[:logprobs]
                }
                top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1188,7 +1182,9 @@ class Llama:
        Returns:
            Generated chat completion or a stream of chat completion chunks.
        """
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
@@ -1296,17 +1292,17 @@ class Llama:
             raise RuntimeError("Failed to set llama state data")

     @staticmethod
-    def token_eos() -> llama_cpp.llama_token:
+    def token_eos() -> int:
         """Return the end-of-sequence token."""
         return llama_cpp.llama_token_eos()

     @staticmethod
-    def token_bos() -> llama_cpp.llama_token:
+    def token_bos() -> int:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()

     @staticmethod
-    def token_nl() -> llama_cpp.llama_token:
+    def token_nl() -> int:
         """Return the newline token."""
         return llama_cpp.llama_token_nl()

@@ -1317,9 +1313,7 @@ class Llama:
         return [math.log(x / sum_exps) for x in exps]

     @staticmethod
-    def longest_token_prefix(
-        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
-    ):
+    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
         longest_prefix = 0
         for _a, _b in zip(a, b):
             if _a == _b:
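The `longest_token_prefix` helper whose signature changes above is what scores candidate cache keys; it now takes plain int sequences. A quick illustrative call (the token id values are made up):

from llama_cpp import Llama

a = [1, 15043, 29892, 3186]
b = [1, 15043, 29892, 29871]
assert Llama.longest_token_prefix(a, b) == 3  # first three ids match, fourth differs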
@@ -44,13 +44,13 @@ def _load_shared_library(lib_base_name: str):
         _base_path = _lib.parent.resolve()
         _lib_paths = [_lib.resolve()]

     cdll_args = dict()  # type: ignore
     # Add the library directory to the DLL search path on Windows (if needed)
     if sys.platform == "win32" and sys.version_info >= (3, 8):
         os.add_dll_directory(str(_base_path))
         if "CUDA_PATH" in os.environ:
-            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"],"bin"))
-            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"],"lib"))
+            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
+            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
         cdll_args["winmode"] = 0

     # Try to load the shared library, handling potential errors
@@ -194,7 +194,7 @@ _lib.llama_init_from_file.restype = llama_context_p

 # Frees all allocated memory
 def llama_free(ctx: llama_context_p):
-    _lib.llama_free(ctx)
+    return _lib.llama_free(ctx)


 _lib.llama_free.argtypes = [llama_context_p]
@@ -206,7 +206,7 @@ _lib.llama_free.restype = None
 # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
 def llama_model_quantize(
     fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
-) -> c_int:
+) -> int:
     return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)


@@ -225,7 +225,7 @@ def llama_apply_lora_from_file(
     path_lora: c_char_p,
     path_base_model: c_char_p,
     n_threads: c_int,
-) -> c_int:
+) -> int:
     return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)


@@ -234,7 +234,7 @@ _lib.llama_apply_lora_from_file.restype = c_int


 # Returns the number of tokens in the KV cache
-def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
+def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
     return _lib.llama_get_kv_cache_token_count(ctx)


@@ -253,7 +253,7 @@ _lib.llama_set_rng_seed.restype = None

 # Returns the maximum size in bytes of the state (rng, logits, embedding
 # and kv_cache) - will often be smaller after compacting tokens
-def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
+def llama_get_state_size(ctx: llama_context_p) -> int:
     return _lib.llama_get_state_size(ctx)


@@ -293,7 +293,7 @@ def llama_load_session_file(
     tokens_out,  # type: Array[llama_token]
     n_token_capacity: c_size_t,
     n_token_count_out,  # type: _Pointer[c_size_t]
-) -> c_size_t:
+) -> int:
     return _lib.llama_load_session_file(
         ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
     )
@@ -314,7 +314,7 @@ def llama_save_session_file(
     path_session: bytes,
     tokens,  # type: Array[llama_token]
     n_token_count: c_size_t,
-) -> c_size_t:
+) -> int:
     return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)


@@ -337,7 +337,7 @@ def llama_eval(
     n_tokens: c_int,
     n_past: c_int,
     n_threads: c_int,
-) -> c_int:
+) -> int:
     return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)


@@ -364,7 +364,7 @@ _lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int,
 _lib.llama_tokenize.restype = c_int


-def llama_n_vocab(ctx: llama_context_p) -> c_int:
+def llama_n_vocab(ctx: llama_context_p) -> int:
     return _lib.llama_n_vocab(ctx)


@@ -372,7 +372,7 @@ _lib.llama_n_vocab.argtypes = [llama_context_p]
 _lib.llama_n_vocab.restype = c_int


-def llama_n_ctx(ctx: llama_context_p) -> c_int:
+def llama_n_ctx(ctx: llama_context_p) -> int:
     return _lib.llama_n_ctx(ctx)


@@ -380,7 +380,7 @@ _lib.llama_n_ctx.argtypes = [llama_context_p]
 _lib.llama_n_ctx.restype = c_int


-def llama_n_embd(ctx: llama_context_p) -> c_int:
+def llama_n_embd(ctx: llama_context_p) -> int:
     return _lib.llama_n_embd(ctx)


@@ -426,7 +426,7 @@ _lib.llama_token_to_str.restype = c_char_p
 # Special tokens


-def llama_token_bos() -> llama_token:
+def llama_token_bos() -> int:
     return _lib.llama_token_bos()


@@ -434,7 +434,7 @@ _lib.llama_token_bos.argtypes = []
 _lib.llama_token_bos.restype = llama_token


-def llama_token_eos() -> llama_token:
+def llama_token_eos() -> int:
     return _lib.llama_token_eos()


@@ -442,7 +442,7 @@ _lib.llama_token_eos.argtypes = []
 _lib.llama_token_eos.restype = llama_token


-def llama_token_nl() -> llama_token:
+def llama_token_nl() -> int:
     return _lib.llama_token_nl()


@@ -625,7 +625,7 @@ def llama_sample_token_mirostat(
     eta: c_float,
     m: c_int,
     mu,  # type: _Pointer[c_float]
-) -> llama_token:
+) -> int:
     return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)


@@ -651,7 +651,7 @@ def llama_sample_token_mirostat_v2(
     tau: c_float,
     eta: c_float,
     mu,  # type: _Pointer[c_float]
-) -> llama_token:
+) -> int:
     return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)


@@ -669,7 +669,7 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
 def llama_sample_token_greedy(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-) -> llama_token:
+) -> int:
     return _lib.llama_sample_token_greedy(ctx, candidates)


@@ -684,7 +684,7 @@ _lib.llama_sample_token_greedy.restype = llama_token
 def llama_sample_token(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
-) -> llama_token:
+) -> int:
     return _lib.llama_sample_token(ctx, candidates)

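The `-> int` annotations in the hunks above describe what the ctypes wrappers already return at runtime: when a foreign function's restype is a simple integer type (c_int, c_size_t, or llama_token), ctypes converts the result to a plain Python int, so annotating the Python wrappers with `c_int`/`llama_token` was misleading. A self-contained sketch of that behaviour on a typical Unix-like system, using libc's abs as a stand-in for the llama.cpp functions:

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int  # declared restype is a ctypes type...

result = libc.abs(-7)
print(type(result), result)      # ...but the call returns a plain int: <class 'int'> 7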
@@ -17,7 +17,7 @@ def test_llama():
 # @pytest.mark.skip(reason="need to update sample mocking")
 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
-    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
+    n_vocab = llama_cpp.llama_n_vocab(llama.ctx)

     ## Set up mock function
     def mock_eval(*args, **kwargs):
@@ -107,7 +107,7 @@ def test_llama_pickle():

 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
-    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
+    n_vocab = llama_cpp.llama_n_vocab(llama.ctx)

     ## Set up mock function
     def mock_eval(*args, **kwargs):
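The test updates follow from the binding change: `llama_cpp.llama_n_vocab` now both behaves and is annotated as returning a plain int, so the `int(...)` cast is redundant. An illustrative pytest-style check (not part of this commit; the model path is a placeholder standing in for the tests' MODEL constant):

import llama_cpp

MODEL = "./models/ggml-vocab.bin"  # placeholder

def test_n_vocab_is_plain_int():
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    n_vocab = llama_cpp.llama_n_vocab(llama.ctx)
    assert isinstance(n_vocab, int)  # no int(...) wrapper needed
    assert n_vocab > 0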