This commit is contained in:
commit
4cb67f59d8
4 changed files with 147 additions and 122 deletions
|
@ -79,6 +79,7 @@ class Llama:
|
|||
n_threads: Optional[int] = None,
|
||||
n_threads_batch: Optional[int] = None,
|
||||
rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
||||
pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
|
||||
rope_freq_base: float = 0.0,
|
||||
rope_freq_scale: float = 0.0,
|
||||
yarn_ext_factor: float = -1.0,
|
||||
|
@ -151,6 +152,7 @@ class Llama:
|
|||
n_threads: Number of threads to use for generation
|
||||
n_threads_batch: Number of threads to use for batch processing
|
||||
rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
pooling_type: Pooling type, from `enum llama_pooling_type`.
|
||||
rope_freq_base: RoPE base frequency, 0 = from model
|
||||
rope_freq_scale: RoPE frequency scaling factor, 0 = from model
|
||||
yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
|
||||
|
@ -271,6 +273,7 @@ class Llama:
|
|||
if rope_scaling_type is not None
|
||||
else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
||||
)
|
||||
self.context_params.pooling_type = pooling_type
|
||||
self.context_params.rope_freq_base = (
|
||||
rope_freq_base if rope_freq_base != 0.0 else 0
|
||||
)
|
||||
|
@ -814,9 +817,12 @@ class Llama:
|
|||
|
||||
# store embeddings
|
||||
for i in range(n_seq):
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
|
||||
ptr = llama_cpp.llama_get_embeddings_seq(
|
||||
self._ctx.ctx, i
|
||||
)[:n_embd]
|
||||
)
|
||||
if not ptr:
|
||||
raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
|
||||
embedding: List[float] = ptr[:n_embd]
|
||||
if normalize:
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
embedding = [v / norm for v in embedding]
|
||||
|
|
|
@ -339,16 +339,7 @@ def chat_formatter_to_chat_completion_handler(
|
|||
stop = stop + rstop
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
try:
|
||||
# create grammar from json schema
|
||||
if "schema" in response_format:
|
||||
grammar = llama_grammar.LlamaGrammar.from_json_schema(
|
||||
json.dumps(response_format["schema"]), verbose=llama.verbose
|
||||
)
|
||||
except Exception as e:
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.JSON_GBNF, verbose=llama.verbose
|
||||
)
|
||||
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
||||
|
||||
completion_or_chunks = llama.create_completion(
|
||||
prompt=prompt,
|
||||
|
@ -606,6 +597,35 @@ def _format_chatglm3(
|
|||
ret += role
|
||||
return ret
|
||||
|
||||
def _grammar_for_json(verbose:bool=False):
|
||||
return llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF, verbose=verbose)
|
||||
|
||||
def _grammar_for_json_schema(
|
||||
schema: str,
|
||||
verbose: bool = False,
|
||||
fallback_to_json: bool = True
|
||||
):
|
||||
try:
|
||||
return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
|
||||
except Exception as e:
|
||||
if fallback_to_json:
|
||||
return _grammar_for_json(verbose=verbose)
|
||||
else:
|
||||
raise e
|
||||
|
||||
def _grammar_for_response_format(
|
||||
response_format: llama_types.ChatCompletionRequestResponseFormat,
|
||||
verbose: bool = False
|
||||
):
|
||||
if response_format["type"] != "json_object":
|
||||
return None
|
||||
|
||||
if "schema" in response_format:
|
||||
return _grammar_for_json_schema(
|
||||
json.dumps(response_format["schema"]), verbose=verbose
|
||||
)
|
||||
else:
|
||||
return _grammar_for_json(verbose=verbose)
|
||||
|
||||
### Chat Formats ###
|
||||
|
||||
|
@ -1994,16 +2014,7 @@ class Llava15ChatHandler:
|
|||
prompt = llama.input_ids[: llama.n_tokens].tolist()
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
try:
|
||||
# create grammar from json schema
|
||||
if "schema" in response_format:
|
||||
grammar = llama_grammar.LlamaGrammar.from_json_schema(
|
||||
json.dumps(response_format["schema"])
|
||||
)
|
||||
except Exception as e:
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.JSON_GBNF
|
||||
)
|
||||
grammar = _grammar_for_response_format(response_format)
|
||||
|
||||
return _convert_completion_to_chat(
|
||||
llama.create_completion(
|
||||
|
@ -2159,26 +2170,10 @@ def chatml_function_calling(
|
|||
tool_calls=None,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
try:
|
||||
grammar = (
|
||||
llama_grammar.LlamaGrammar.from_json_schema(
|
||||
json.dumps(response_format["schema"])
|
||||
)
|
||||
if "schema" in response_format
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
if llama.verbose:
|
||||
print(
|
||||
"Failed to parse response format as JSON schema, falling back to default grammar"
|
||||
)
|
||||
print(e)
|
||||
grammar = (
|
||||
llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||
if grammar is None
|
||||
else grammar
|
||||
)
|
||||
grammar = _grammar_for_response_format(response_format)
|
||||
|
||||
return _convert_completion_to_chat(
|
||||
llama.create_completion(
|
||||
prompt=prompt,
|
||||
|
|
|
@ -198,13 +198,15 @@ llama_seq_id = ctypes.c_int32
|
|||
|
||||
|
||||
# enum llama_vocab_type {
|
||||
# LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
# LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
# LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
|
||||
# LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||
# LLAMA_VOCAB_TYPE_SPM = 1, // SentencePiece
|
||||
# LLAMA_VOCAB_TYPE_BPE = 2, // Byte Pair Encoding
|
||||
# LLAMA_VOCAB_TYPE_WPM = 3, // WordPiece
|
||||
# };
|
||||
LLAMA_VOCAB_TYPE_SPM = 0
|
||||
LLAMA_VOCAB_TYPE_BPE = 1
|
||||
LLAMA_VOCAB_TYPE_WPM = 2
|
||||
LLAMA_VOCAB_TYPE_NONE = 0
|
||||
LLAMA_VOCAB_TYPE_SPM = 1
|
||||
LLAMA_VOCAB_TYPE_BPE = 2
|
||||
LLAMA_VOCAB_TYPE_WPM = 3
|
||||
|
||||
|
||||
# // note: these values should be synchronized with ggml_rope
|
||||
|
@ -548,8 +550,9 @@ class llama_model_params(ctypes.Structure):
|
|||
# struct llama_context_params {
|
||||
# uint32_t seed; // RNG seed, -1 for random
|
||||
# uint32_t n_ctx; // text context, 0 = from model
|
||||
# uint32_t n_batch; // prompt processing maximum batch size
|
||||
# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
|
||||
# uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
|
||||
# uint32_t n_ubatch; // physical maximum batch size
|
||||
# uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
|
||||
# uint32_t n_threads; // number of threads to use for generation
|
||||
# uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
|
@ -590,8 +593,9 @@ class llama_context_params(ctypes.Structure):
|
|||
Attributes:
|
||||
seed (int): RNG seed, -1 for random
|
||||
n_ctx (int): text context, 0 = from model
|
||||
n_batch (int): prompt processing maximum batch size
|
||||
n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
|
||||
n_batch (int): logical maximum batch size that can be submitted to llama_decode
|
||||
n_ubatch (int): physical maximum batch size
|
||||
n_seq_max (int): max number of sequences (i.e. distinct states for recurrent models)
|
||||
n_threads (int): number of threads to use for generation
|
||||
n_threads_batch (int): number of threads to use for batch processing
|
||||
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
@ -619,7 +623,8 @@ class llama_context_params(ctypes.Structure):
|
|||
("seed", ctypes.c_uint32),
|
||||
("n_ctx", ctypes.c_uint32),
|
||||
("n_batch", ctypes.c_uint32),
|
||||
("n_parallel", ctypes.c_uint32),
|
||||
("n_ubatch", ctypes.c_uint32),
|
||||
("n_seq_max", ctypes.c_uint32),
|
||||
("n_threads", ctypes.c_uint32),
|
||||
("n_threads_batch", ctypes.c_uint32),
|
||||
("rope_scaling_type", ctypes.c_int),
|
||||
|
@ -667,7 +672,7 @@ It might not exist for progress report where '.' is output repeatedly."""
|
|||
# bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
# bool quantize_output_tensor; // quantize output.weight
|
||||
# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
# bool pure; // disable k-quant mixtures and quantize all tensors to the same type
|
||||
# bool pure; // quantize all tensors to the default type
|
||||
# void * imatrix; // pointer to importance matrix data
|
||||
# } llama_model_quantize_params;
|
||||
class llama_model_quantize_params(ctypes.Structure):
|
||||
|
@ -679,7 +684,7 @@ class llama_model_quantize_params(ctypes.Structure):
|
|||
allow_requantize (bool): allow quantizing non-f32/f16 tensors
|
||||
quantize_output_tensor (bool): quantize output.weight
|
||||
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
pure (bool): disable k-quant mixtures and quantize all tensors to the same type
|
||||
pure (bool): quantize all tensors to the default type
|
||||
imatrix (ctypes.ctypes.c_void_p): pointer to importance matrix data
|
||||
"""
|
||||
|
||||
|
@ -860,8 +865,7 @@ GGML_NUMA_STRATEGY_COUNT = 5
|
|||
[ctypes.c_int],
|
||||
None,
|
||||
)
|
||||
def llama_numa_init(numa: int, /):
|
||||
...
|
||||
def llama_numa_init(numa: int, /): ...
|
||||
|
||||
|
||||
# // Call once at the end of the program - currently only used for MPI
|
||||
|
@ -886,8 +890,7 @@ def llama_backend_free():
|
|||
)
|
||||
def llama_load_model_from_file(
|
||||
path_model: bytes, params: llama_model_params, /
|
||||
) -> Optional[llama_model_p]:
|
||||
...
|
||||
) -> Optional[llama_model_p]: ...
|
||||
|
||||
|
||||
# LLAMA_API void llama_free_model(struct llama_model * model);
|
||||
|
@ -896,8 +899,7 @@ def llama_load_model_from_file(
|
|||
[llama_model_p_ctypes],
|
||||
None,
|
||||
)
|
||||
def llama_free_model(model: llama_model_p, /):
|
||||
...
|
||||
def llama_free_model(model: llama_model_p, /): ...
|
||||
|
||||
|
||||
# LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
|
@ -910,8 +912,7 @@ def llama_free_model(model: llama_model_p, /):
|
|||
)
|
||||
def llama_new_context_with_model(
|
||||
model: llama_model_p, params: llama_context_params, /
|
||||
) -> Optional[llama_context_p]:
|
||||
...
|
||||
) -> Optional[llama_context_p]: ...
|
||||
|
||||
|
||||
# // Frees all allocated memory
|
||||
|
@ -932,80 +933,77 @@ def llama_free(ctx: llama_context_p, /):
|
|||
[],
|
||||
ctypes.c_int64,
|
||||
)
|
||||
def llama_time_us() -> int:
|
||||
...
|
||||
def llama_time_us() -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API size_t llama_max_devices(void);
|
||||
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
|
||||
def llama_max_devices() -> int:
|
||||
...
|
||||
def llama_max_devices() -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_supports_mmap (void);
|
||||
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
|
||||
def llama_supports_mmap() -> bool:
|
||||
...
|
||||
def llama_supports_mmap() -> bool: ...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_supports_mlock (void);
|
||||
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
|
||||
def llama_supports_mlock() -> bool:
|
||||
...
|
||||
def llama_supports_mlock() -> bool: ...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
|
||||
def llama_supports_gpu_offload() -> bool:
|
||||
...
|
||||
def llama_supports_gpu_offload() -> bool: ...
|
||||
|
||||
|
||||
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
|
||||
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
|
||||
...
|
||||
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_ctx(ctx: llama_context_p, /) -> int:
|
||||
...
|
||||
def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_batch(ctx: llama_context_p, /) -> int:
|
||||
...
|
||||
def llama_n_batch(ctx: llama_context_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
|
||||
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||
def llama_vocab_type(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_vocab_type(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||
def llama_rope_type(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_rope_type(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_vocab(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_n_vocab(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_ctx_train(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_embd(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_n_embd(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# // Get the model's RoPE frequency scaling factor
|
||||
|
@ -1192,8 +1190,7 @@ def llama_model_apply_lora_from_file(
|
|||
path_base_model: Union[ctypes.c_char_p, bytes, None],
|
||||
n_threads: Union[ctypes.c_int32, int],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
) -> int: ...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -1219,7 +1216,7 @@ class llama_kv_cache_view_cell(ctypes.Structure):
|
|||
# // Maximum number of sequences that can exist in a cell. It's not an error
|
||||
# // if there are more sequences in a cell than this value, however they will
|
||||
# // not be visible in the view cells_sequences.
|
||||
# int32_t n_max_seq;
|
||||
# int32_t n_seq_max;
|
||||
|
||||
# // Number of tokens in the cache. For example, if there are two populated
|
||||
# // cells, the first with 1 sequence id in it and the second with 2 sequence
|
||||
|
@ -1240,7 +1237,7 @@ class llama_kv_cache_view_cell(ctypes.Structure):
|
|||
# struct llama_kv_cache_view_cell * cells;
|
||||
|
||||
|
||||
# // The sequences for each cell. There will be n_max_seq items per cell.
|
||||
# // The sequences for each cell. There will be n_seq_max items per cell.
|
||||
# llama_seq_id * cells_sequences;
|
||||
# };
|
||||
class llama_kv_cache_view(ctypes.Structure):
|
||||
|
@ -1260,14 +1257,14 @@ llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view)
|
|||
|
||||
|
||||
# // Create an empty KV cache view. (use only for debugging purposes)
|
||||
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
|
||||
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
||||
@ctypes_function(
|
||||
"llama_kv_cache_view_init",
|
||||
[llama_context_p_ctypes, ctypes.c_int32],
|
||||
llama_kv_cache_view,
|
||||
)
|
||||
def llama_kv_cache_view_init(
|
||||
ctx: llama_context_p, n_max_seq: Union[ctypes.c_int32, int], /
|
||||
ctx: llama_context_p, n_seq_max: Union[ctypes.c_int32, int], /
|
||||
) -> llama_kv_cache_view:
|
||||
"""Create an empty KV cache view. (use only for debugging purposes)"""
|
||||
...
|
||||
|
@ -1582,8 +1579,7 @@ def llama_load_session_file(
|
|||
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API bool llama_save_session_file(
|
||||
|
@ -1607,8 +1603,7 @@ def llama_save_session_file(
|
|||
tokens: CtypesArray[llama_token],
|
||||
n_token_count: Union[ctypes.c_size_t, int],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
) -> int: ...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -1728,6 +1723,17 @@ def llama_set_n_threads(
|
|||
"""
|
||||
...
|
||||
|
||||
|
||||
# // Set whether to use causal attention or not
|
||||
# // If set to true, the model will only attend to the past tokens
|
||||
# LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
|
||||
@ctypes_function("llama_set_causal_attn", [llama_context_p_ctypes, ctypes.c_bool], None)
|
||||
def llama_set_causal_attn(ctx: llama_context_p, causal_attn: bool, /):
|
||||
"""Set whether to use causal attention or not
|
||||
If set to true, the model will only attend to the past tokens"""
|
||||
...
|
||||
|
||||
|
||||
# // Set abort callback
|
||||
# LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||
@ctypes_function(
|
||||
|
@ -1745,6 +1751,18 @@ def llama_set_abort_callback(
|
|||
...
|
||||
|
||||
|
||||
# // Wait until all computations are finished
|
||||
# // This is automatically done when using one of the functions below to obtain the computation results
|
||||
# // and is not necessary to call it explicitly in most cases
|
||||
# LLAMA_API void llama_synchronize(struct llama_context * ctx);
|
||||
@ctypes_function("llama_synchronize", [llama_context_p_ctypes], None)
|
||||
def llama_synchronize(ctx: llama_context_p, /):
|
||||
"""Wait until all computations are finished
|
||||
This is automatically done when using one of the functions below to obtain the computation results
|
||||
and is not necessary to call it explicitly in most cases"""
|
||||
...
|
||||
|
||||
|
||||
# // Token logits obtained from the last call to llama_decode()
|
||||
# // The logits for the last token are stored in the last row
|
||||
# // Logits for which llama_batch.logits[i] == 0 are undefined
|
||||
|
@ -1760,7 +1778,7 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
|
|||
Logits for which llama_batch.logits[i] == 0 are undefined
|
||||
Rows: n_tokens provided with llama_batch
|
||||
Cols: n_vocab
|
||||
|
||||
|
||||
Returns:
|
||||
Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
|
||||
...
|
||||
|
@ -1828,6 +1846,7 @@ def llama_get_embeddings_seq(
|
|||
shape: [n_embd] (1-dimensional)"""
|
||||
...
|
||||
|
||||
|
||||
# //
|
||||
# // Vocab
|
||||
# //
|
||||
|
@ -1839,8 +1858,7 @@ def llama_get_embeddings_seq(
|
|||
)
|
||||
def llama_token_get_text(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> bytes:
|
||||
...
|
||||
) -> bytes: ...
|
||||
|
||||
|
||||
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
|
||||
|
@ -1849,8 +1867,7 @@ def llama_token_get_text(
|
|||
)
|
||||
def llama_token_get_score(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> float:
|
||||
...
|
||||
) -> float: ...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
||||
|
@ -1859,8 +1876,7 @@ def llama_token_get_score(
|
|||
)
|
||||
def llama_token_get_type(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> int:
|
||||
...
|
||||
) -> int: ...
|
||||
|
||||
|
||||
# // Special tokens
|
||||
|
@ -1913,20 +1929,17 @@ def llama_token_prefix(model: llama_model_p) -> int:
|
|||
|
||||
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
||||
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_middle(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_token_middle(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
|
||||
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_suffix(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_token_suffix(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
|
||||
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
|
||||
def llama_token_eot(model: llama_model_p, /) -> int:
|
||||
...
|
||||
def llama_token_eot(model: llama_model_p, /) -> int: ...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -1936,7 +1949,7 @@ def llama_token_eot(model: llama_model_p, /) -> int:
|
|||
|
||||
# /// @details Convert the provided text into tokens.
|
||||
# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
||||
# /// @return Returns the number of tokens on success, no more than n_max_tokens
|
||||
# /// @return Returns the number of tokens on success, no more than n_tokens_max
|
||||
# /// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
||||
# /// Does not insert a leading space.
|
||||
|
@ -1945,7 +1958,7 @@ def llama_token_eot(model: llama_model_p, /) -> int:
|
|||
# const char * text,
|
||||
# int32_t text_len,
|
||||
# llama_token * tokens,
|
||||
# int32_t n_max_tokens,
|
||||
# int32_t n_tokens_max,
|
||||
# bool add_bos,
|
||||
# bool special);
|
||||
@ctypes_function(
|
||||
|
@ -1966,12 +1979,26 @@ def llama_tokenize(
|
|||
text: bytes,
|
||||
text_len: Union[ctypes.c_int, int],
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_max_tokens: Union[ctypes.c_int, int],
|
||||
n_tokens_max: Union[ctypes.c_int, int],
|
||||
add_bos: Union[ctypes.c_bool, bool],
|
||||
special: Union[ctypes.c_bool, bool],
|
||||
/,
|
||||
) -> int:
|
||||
"""Convert the provided text into tokens."""
|
||||
"""Convert the provided text into tokens.
|
||||
|
||||
Args:
|
||||
model: The model to use for tokenization.
|
||||
text: The text to tokenize.
|
||||
text_len: The length of the text.
|
||||
tokens: The tokens pointer must be large enough to hold the resulting tokens.
|
||||
n_max_tokens: The maximum number of tokens to return.
|
||||
add_bos: Whether to add a beginning-of-sentence token.
|
||||
special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
||||
Does not insert a leading space.
|
||||
|
||||
Returns:
|
||||
Returns the number of tokens on success, no more than n_tokens_max
|
||||
Returns a negative number on failure - the number of tokens that would have been returned"""
|
||||
...
|
||||
|
||||
|
||||
|
@ -2043,8 +2070,7 @@ def llama_chat_apply_template(
|
|||
chat: CtypesArray[llama_chat_message],
|
||||
n_msg: int,
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
) -> int: ...
|
||||
|
||||
|
||||
# //
|
||||
|
@ -2645,8 +2671,7 @@ def llama_beam_search(
|
|||
n_past: Union[ctypes.c_int, int],
|
||||
n_predict: Union[ctypes.c_int, int],
|
||||
/,
|
||||
):
|
||||
...
|
||||
): ...
|
||||
|
||||
|
||||
# Performance information
|
||||
|
@ -2723,5 +2748,4 @@ def llama_log_set(
|
|||
[ctypes.c_void_p, llama_context_p_ctypes],
|
||||
None,
|
||||
)
|
||||
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
|
||||
...
|
||||
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit c2101a2e909ac7c08976d414e64e96c90ee5fa9e
|
||||
Subproject commit 4e9a7f7f7fb6acbddd1462909c8d696e38edbfcc
|
Loading…
Reference in a new issue