Update llama.cpp

This commit is contained in:
Andrei Betlen 2024-01-03 22:04:04 -05:00
parent 011c3630f5
commit eb9c7d4ed8
2 changed files with 74 additions and 66 deletions

View file

@ -93,6 +93,9 @@ c_size_t_p = POINTER(c_size_t)
# llama.h bindings # llama.h bindings
_lib.llama_max_devices.argtypes = []
_lib.llama_max_devices.restype = ctypes.c_int32
LLAMA_MAX_DEVICES = _lib.llama_max_devices() LLAMA_MAX_DEVICES = _lib.llama_max_devices()
# define LLAMA_DEFAULT_SEED 0xFFFFFFFF # define LLAMA_DEFAULT_SEED 0xFFFFFFFF
@ -481,7 +484,7 @@ It might not exist for progress report where '.' is output repeatedly."""
# // model quantization parameters # // model quantization parameters
# typedef struct llama_model_quantize_params { # typedef struct llama_model_quantize_params {
# int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() # int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
# enum llama_ftype ftype; // quantize to this llama_ftype # enum llama_ftype ftype; // quantize to this llama_ftype
# bool allow_requantize; // allow quantizing non-f32/f16 tensors # bool allow_requantize; // allow quantizing non-f32/f16 tensors
# bool quantize_output_tensor; // quantize output.weight # bool quantize_output_tensor; // quantize output.weight
@ -499,7 +502,7 @@ class llama_model_quantize_params(Structure):
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
pure (bool): disable k-quant mixtures and quantize all tensors to the same type""" pure (bool): disable k-quant mixtures and quantize all tensors to the same type"""
_fields_ = [ _fields_ = [
("nthread", c_int), ("nthread", c_int32),
("ftype", c_int), ("ftype", c_int),
("allow_requantize", c_bool), ("allow_requantize", c_bool),
("quantize_output_tensor", c_bool), ("quantize_output_tensor", c_bool),
@ -698,13 +701,13 @@ _lib.llama_time_us.argtypes = []
_lib.llama_time_us.restype = ctypes.c_int64 _lib.llama_time_us.restype = ctypes.c_int64
# LLAMA_API int llama_max_devices (void); # LLAMA_API int32_t llama_max_devices(void);
def llama_max_devices() -> int: def llama_max_devices() -> int:
return _lib.llama_max_devices() return _lib.llama_max_devices()
_lib.llama_max_devices.argtypes = [] _lib.llama_max_devices.argtypes = []
_lib.llama_max_devices.restype = c_int _lib.llama_max_devices.restype = ctypes.c_int32
# LLAMA_API bool llama_mmap_supported (void); # LLAMA_API bool llama_mmap_supported (void);
@ -734,7 +737,7 @@ _lib.llama_get_model.argtypes = [llama_context_p]
_lib.llama_get_model.restype = llama_model_p _lib.llama_get_model.restype = llama_model_p
# LLAMA_API int llama_n_ctx (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
def llama_n_ctx(ctx: llama_context_p) -> int: def llama_n_ctx(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx(ctx) return _lib.llama_n_ctx(ctx)
@ -758,31 +761,31 @@ _lib.llama_vocab_type.argtypes = [llama_model_p]
_lib.llama_vocab_type.restype = c_int _lib.llama_vocab_type.restype = c_int
# LLAMA_API int llama_n_vocab (const struct llama_model * model); # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
def llama_n_vocab(model: llama_model_p) -> int: def llama_n_vocab(model: llama_model_p) -> int:
return _lib.llama_n_vocab(model) return _lib.llama_n_vocab(model)
_lib.llama_n_vocab.argtypes = [llama_model_p] _lib.llama_n_vocab.argtypes = [llama_model_p]
_lib.llama_n_vocab.restype = c_int _lib.llama_n_vocab.restype = c_int32
# LLAMA_API int llama_n_ctx_train(const struct llama_model * model); # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
def llama_n_ctx_train(model: llama_model_p) -> int: def llama_n_ctx_train(model: llama_model_p) -> int:
return _lib.llama_n_ctx_train(model) return _lib.llama_n_ctx_train(model)
_lib.llama_n_ctx_train.argtypes = [llama_model_p] _lib.llama_n_ctx_train.argtypes = [llama_model_p]
_lib.llama_n_ctx_train.restype = c_int _lib.llama_n_ctx_train.restype = c_int32
# LLAMA_API int llama_n_embd (const struct llama_model * model); # LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
def llama_n_embd(model: llama_model_p) -> int: def llama_n_embd(model: llama_model_p) -> int:
return _lib.llama_n_embd(model) return _lib.llama_n_embd(model)
_lib.llama_n_embd.argtypes = [llama_model_p] _lib.llama_n_embd.argtypes = [llama_model_p]
_lib.llama_n_embd.restype = c_int _lib.llama_n_embd.restype = c_int32
# // Get the model's RoPE frequency scaling factor # // Get the model's RoPE frequency scaling factor
@ -802,7 +805,7 @@ _lib.llama_rope_freq_scale_train.restype = c_float
# // Get metadata value as a string by key name # // Get metadata value as a string by key name
# LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size); # LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
def llama_model_meta_val_str( def llama_model_meta_val_str(
model: llama_model_p, key: Union[c_char_p, bytes], buf: bytes, buf_size: int model: llama_model_p, key: Union[c_char_p, bytes], buf: bytes, buf_size: int
) -> int: ) -> int:
@ -811,22 +814,22 @@ def llama_model_meta_val_str(
_lib.llama_model_meta_val_str.argtypes = [llama_model_p, c_char_p, c_char_p, c_size_t] _lib.llama_model_meta_val_str.argtypes = [llama_model_p, c_char_p, c_char_p, c_size_t]
_lib.llama_model_meta_val_str.restype = c_int _lib.llama_model_meta_val_str.restype = c_int32
# // Get the number of metadata key/value pairs # // Get the number of metadata key/value pairs
# LLAMA_API int llama_model_meta_count(const struct llama_model * model); # LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
def llama_model_meta_count(model: llama_model_p) -> int: def llama_model_meta_count(model: llama_model_p) -> int:
"""Get the number of metadata key/value pairs""" """Get the number of metadata key/value pairs"""
return _lib.llama_model_meta_count(model) return _lib.llama_model_meta_count(model)
_lib.llama_model_meta_count.argtypes = [llama_model_p] _lib.llama_model_meta_count.argtypes = [llama_model_p]
_lib.llama_model_meta_count.restype = c_int _lib.llama_model_meta_count.restype = c_int32
# // Get metadata key name by index # // Get metadata key name by index
# LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size); # LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
def llama_model_meta_key_by_index( def llama_model_meta_key_by_index(
model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int
) -> int: ) -> int:
@ -834,12 +837,17 @@ def llama_model_meta_key_by_index(
return _lib.llama_model_meta_key_by_index(model, i, buf, buf_size) return _lib.llama_model_meta_key_by_index(model, i, buf, buf_size)
_lib.llama_model_meta_key_by_index.argtypes = [llama_model_p, c_int, c_char_p, c_size_t] _lib.llama_model_meta_key_by_index.argtypes = [
_lib.llama_model_meta_key_by_index.restype = c_int llama_model_p,
c_int32,
c_char_p,
c_size_t,
]
_lib.llama_model_meta_key_by_index.restype = c_int32
# // Get metadata value as a string by index # // Get metadata value as a string by index
# LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size); # LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
def llama_model_meta_val_str_by_index( def llama_model_meta_val_str_by_index(
model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int
) -> int: ) -> int:
@ -849,15 +857,15 @@ def llama_model_meta_val_str_by_index(
_lib.llama_model_meta_val_str_by_index.argtypes = [ _lib.llama_model_meta_val_str_by_index.argtypes = [
llama_model_p, llama_model_p,
c_int, c_int32,
c_char_p, c_char_p,
c_size_t, c_size_t,
] ]
_lib.llama_model_meta_val_str_by_index.restype = c_int _lib.llama_model_meta_val_str_by_index.restype = c_int32
# // Get a string describing the model type # // Get a string describing the model type
# LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); # LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
def llama_model_desc( def llama_model_desc(
model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int] model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int]
) -> int: ) -> int:
@ -866,7 +874,7 @@ def llama_model_desc(
_lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t] _lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t]
_lib.llama_model_desc.restype = c_int _lib.llama_model_desc.restype = c_int32
# // Returns the total size of all the tensors in the model in bytes # // Returns the total size of all the tensors in the model in bytes
@ -905,7 +913,7 @@ _lib.llama_get_model_tensor.restype = c_void_p
# // Returns 0 on success # // Returns 0 on success
# LLAMA_API int llama_model_quantize( # LLAMA_API uint32_t llama_model_quantize(
# const char * fname_inp, # const char * fname_inp,
# const char * fname_out, # const char * fname_out,
# const llama_model_quantize_params * params); # const llama_model_quantize_params * params);
@ -923,7 +931,7 @@ _lib.llama_model_quantize.argtypes = [
c_char_p, c_char_p,
POINTER(llama_model_quantize_params), POINTER(llama_model_quantize_params),
] ]
_lib.llama_model_quantize.restype = c_int _lib.llama_model_quantize.restype = c_uint32
# // Apply a LoRA adapter to a loaded model # // Apply a LoRA adapter to a loaded model
@ -932,12 +940,12 @@ _lib.llama_model_quantize.restype = c_int
# // The model needs to be reloaded before applying a new adapter, otherwise the adapter # // The model needs to be reloaded before applying a new adapter, otherwise the adapter
# // will be applied on top of the previous one # // will be applied on top of the previous one
# // Returns 0 on success # // Returns 0 on success
# LLAMA_API DEPRECATED(int llama_apply_lora_from_file( # LLAMA_API DEPRECATED(int32_t llama_apply_lora_from_file(
# struct llama_context * ctx, # struct llama_context * ctx,
# const char * path_lora, # const char * path_lora,
# float scale, # float scale,
# const char * path_base_model, # const char * path_base_model,
# int n_threads), # int32_t n_threads),
# "use llama_model_apply_lora_from_file instead"); # "use llama_model_apply_lora_from_file instead");
def llama_apply_lora_from_file( def llama_apply_lora_from_file(
ctx: llama_context_p, ctx: llama_context_p,
@ -962,17 +970,17 @@ _lib.llama_apply_lora_from_file.argtypes = [
c_char_p, c_char_p,
c_float, c_float,
c_char_p, c_char_p,
c_int, c_int32,
] ]
_lib.llama_apply_lora_from_file.restype = c_int _lib.llama_apply_lora_from_file.restype = c_int32
# LLAMA_API int llama_model_apply_lora_from_file( # LLAMA_API int32_t llama_model_apply_lora_from_file(
# const struct llama_model * model, # const struct llama_model * model,
# const char * path_lora, # const char * path_lora,
# float scale, # float scale,
# const char * path_base_model, # const char * path_base_model,
# int n_threads); # int32_t n_threads);
def llama_model_apply_lora_from_file( def llama_model_apply_lora_from_file(
model: llama_model_p, model: llama_model_p,
path_lora: Union[c_char_p, bytes], path_lora: Union[c_char_p, bytes],
@ -990,9 +998,9 @@ _lib.llama_model_apply_lora_from_file.argtypes = [
c_char_p, c_char_p,
c_float, c_float,
c_char_p, c_char_p,
c_int, c_int32,
] ]
_lib.llama_model_apply_lora_from_file.restype = c_int _lib.llama_model_apply_lora_from_file.restype = c_int32
# // # //
# // KV cache # // KV cache
@ -1094,7 +1102,7 @@ _lib.llama_kv_cache_view_update.restype = None
# // Returns the number of tokens in the KV cache (slow, use only for debug) # // Returns the number of tokens in the KV cache (slow, use only for debug)
# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times # // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
# LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); # LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
"""Returns the number of tokens in the KV cache (slow, use only for debug) """Returns the number of tokens in the KV cache (slow, use only for debug)
If a KV cell has multiple sequences assigned to it, it will be counted multiple times If a KV cell has multiple sequences assigned to it, it will be counted multiple times
@ -1103,18 +1111,18 @@ def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p] _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int _lib.llama_get_kv_cache_token_count.restype = c_int32
# // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) # // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
# LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx); # LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
def llama_get_kv_cache_used_cells(ctx: llama_context_p) -> int: def llama_get_kv_cache_used_cells(ctx: llama_context_p) -> int:
"""Returns the number of used KV cells (i.e. have at least one sequence assigned to them)""" """Returns the number of used KV cells (i.e. have at least one sequence assigned to them)"""
return _lib.llama_get_kv_cache_used_cells(ctx) return _lib.llama_get_kv_cache_used_cells(ctx)
_lib.llama_get_kv_cache_used_cells.argtypes = [llama_context_p] _lib.llama_get_kv_cache_used_cells.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_used_cells.restype = c_int _lib.llama_get_kv_cache_used_cells.restype = c_int32
# // Clear the KV cache # // Clear the KV cache
@ -1361,7 +1369,7 @@ _lib.llama_save_session_file.restype = c_size_t
# struct llama_context * ctx, # struct llama_context * ctx,
# llama_token * tokens, # llama_token * tokens,
# int32_t n_tokens, # int32_t n_tokens,
# int n_past), # int32_t n_past),
# "use llama_decode() instead"); # "use llama_decode() instead");
def llama_eval( def llama_eval(
ctx: llama_context_p, ctx: llama_context_p,
@ -1377,7 +1385,7 @@ def llama_eval(
return _lib.llama_eval(ctx, tokens, n_tokens, n_past) return _lib.llama_eval(ctx, tokens, n_tokens, n_past)
_lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int, c_int] _lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int32, c_int32]
_lib.llama_eval.restype = c_int _lib.llama_eval.restype = c_int
@ -1387,7 +1395,7 @@ _lib.llama_eval.restype = c_int
# struct llama_context * ctx, # struct llama_context * ctx,
# float * embd, # float * embd,
# int32_t n_tokens, # int32_t n_tokens,
# int n_past), # int32_t n_past),
# "use llama_decode() instead"); # "use llama_decode() instead");
def llama_eval_embd( def llama_eval_embd(
ctx: llama_context_p, ctx: llama_context_p,
@ -1400,7 +1408,7 @@ def llama_eval_embd(
return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past) return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past)
_lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int, c_int] _lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int32, c_int32]
_lib.llama_eval_embd.restype = c_int _lib.llama_eval_embd.restype = c_int
@ -1480,7 +1488,7 @@ _lib.llama_batch_free.restype = None
# // 0 - success # // 0 - success
# // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) # // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
# // < 0 - error # // < 0 - error
# LLAMA_API int llama_decode( # LLAMA_API int32_t llama_decode(
# struct llama_context * ctx, # struct llama_context * ctx,
# struct llama_batch batch); # struct llama_batch batch);
def llama_decode(ctx: llama_context_p, batch: llama_batch) -> int: def llama_decode(ctx: llama_context_p, batch: llama_batch) -> int:
@ -1492,7 +1500,7 @@ def llama_decode(ctx: llama_context_p, batch: llama_batch) -> int:
_lib.llama_decode.argtypes = [llama_context_p, llama_batch] _lib.llama_decode.argtypes = [llama_context_p, llama_batch]
_lib.llama_decode.restype = c_int _lib.llama_decode.restype = c_int32
# // Set the number of threads used for decoding # // Set the number of threads used for decoding
@ -1634,25 +1642,25 @@ _lib.llama_token_nl.restype = llama_token
# // Returns -1 if unknown, 1 for true or 0 for false. # // Returns -1 if unknown, 1 for true or 0 for false.
# LLAMA_API int llama_add_bos_token(const struct llama_model * model); # LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
def llama_add_bos_token(model: llama_model_p) -> int: def llama_add_bos_token(model: llama_model_p) -> int:
"""Returns -1 if unknown, 1 for true or 0 for false.""" """Returns -1 if unknown, 1 for true or 0 for false."""
return _lib.llama_add_bos_token(model) return _lib.llama_add_bos_token(model)
_lib.llama_add_bos_token.argtypes = [llama_model_p] _lib.llama_add_bos_token.argtypes = [llama_model_p]
_lib.llama_add_bos_token.restype = c_int _lib.llama_add_bos_token.restype = c_int32
# // Returns -1 if unknown, 1 for true or 0 for false. # // Returns -1 if unknown, 1 for true or 0 for false.
# LLAMA_API int llama_add_eos_token(const struct llama_model * model); # LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
def llama_add_eos_token(model: llama_model_p) -> int: def llama_add_eos_token(model: llama_model_p) -> int:
"""Returns -1 if unknown, 1 for true or 0 for false.""" """Returns -1 if unknown, 1 for true or 0 for false."""
return _lib.llama_add_eos_token(model) return _lib.llama_add_eos_token(model)
_lib.llama_add_eos_token.argtypes = [llama_model_p] _lib.llama_add_eos_token.argtypes = [llama_model_p]
_lib.llama_add_eos_token.restype = c_int _lib.llama_add_eos_token.restype = c_int32
# // codellama infill tokens # // codellama infill tokens
@ -1704,12 +1712,12 @@ _lib.llama_token_eot.restype = llama_token
# /// @return Returns a negative number on failure - the number of tokens that would have been returned # /// @return Returns a negative number on failure - the number of tokens that would have been returned
# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. # /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
# /// Does not insert a leading space. # /// Does not insert a leading space.
# LLAMA_API int llama_tokenize( # LLAMA_API int32_t llama_tokenize(
# const struct llama_model * model, # const struct llama_model * model,
# const char * text, # const char * text,
# int text_len, # int32_t text_len,
# llama_token * tokens, # llama_token * tokens,
# int n_max_tokens, # int32_t n_max_tokens,
# bool add_bos, # bool add_bos,
# bool special); # bool special);
def llama_tokenize( def llama_tokenize(
@ -1730,24 +1738,24 @@ def llama_tokenize(
_lib.llama_tokenize.argtypes = [ _lib.llama_tokenize.argtypes = [
llama_model_p, llama_model_p,
c_char_p, c_char_p,
c_int, c_int32,
llama_token_p, llama_token_p,
c_int, c_int32,
c_bool, c_bool,
c_bool, c_bool,
] ]
_lib.llama_tokenize.restype = c_int _lib.llama_tokenize.restype = c_int32
# // Token Id -> Piece. # // Token Id -> Piece.
# // Uses the vocabulary in the provided context. # // Uses the vocabulary in the provided context.
# // Does not write null terminator to the buffer. # // Does not write null terminator to the buffer.
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. # // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
# LLAMA_API int llama_token_to_piece( # LLAMA_API int32_t llama_token_to_piece(
# const struct llama_model * model, # const struct llama_model * model,
# llama_token token, # llama_token token,
# char * buf, # char * buf,
# int length); # int32_t length);
def llama_token_to_piece( def llama_token_to_piece(
model: llama_model_p, model: llama_model_p,
token: Union[llama_token, int], token: Union[llama_token, int],
@ -1762,8 +1770,8 @@ def llama_token_to_piece(
return _lib.llama_token_to_piece(model, token, buf, length) return _lib.llama_token_to_piece(model, token, buf, length)
_lib.llama_token_to_piece.argtypes = [llama_model_p, llama_token, c_char_p, c_int] _lib.llama_token_to_piece.argtypes = [llama_model_p, llama_token, c_char_p, c_int32]
_lib.llama_token_to_piece.restype = c_int _lib.llama_token_to_piece.restype = c_int32
# // # //
@ -1924,7 +1932,7 @@ _lib.llama_sample_softmax.restype = None
# LLAMA_API void llama_sample_top_k( # LLAMA_API void llama_sample_top_k(
# struct llama_context * ctx, # struct llama_context * ctx,
# llama_token_data_array * candidates, # llama_token_data_array * candidates,
# int k, # int32_t k,
# size_t min_keep); # size_t min_keep);
def llama_sample_top_k( def llama_sample_top_k(
ctx: llama_context_p, ctx: llama_context_p,
@ -1939,7 +1947,7 @@ def llama_sample_top_k(
_lib.llama_sample_top_k.argtypes = [ _lib.llama_sample_top_k.argtypes = [
llama_context_p, llama_context_p,
llama_token_data_array_p, llama_token_data_array_p,
c_int, c_int32,
c_size_t, c_size_t,
] ]
_lib.llama_sample_top_k.restype = None _lib.llama_sample_top_k.restype = None
@ -2129,7 +2137,7 @@ _lib.llama_sample_grammar.restype = None
# llama_token_data_array * candidates, # llama_token_data_array * candidates,
# float tau, # float tau,
# float eta, # float eta,
# int m, # int32_t m,
# float * mu); # float * mu);
def llama_sample_token_mirostat( def llama_sample_token_mirostat(
ctx: llama_context_p, ctx: llama_context_p,
@ -2155,7 +2163,7 @@ _lib.llama_sample_token_mirostat.argtypes = [
llama_token_data_array_p, llama_token_data_array_p,
c_float, c_float,
c_float, c_float,
c_int, c_int32,
c_float_p, c_float_p,
] ]
_lib.llama_sample_token_mirostat.restype = llama_token _lib.llama_sample_token_mirostat.restype = llama_token
@ -2320,8 +2328,8 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_s
# llama_beam_search_callback_fn_t callback, # llama_beam_search_callback_fn_t callback,
# void * callback_data, # void * callback_data,
# size_t n_beams, # size_t n_beams,
# int n_past, # int32_t n_past,
# int n_predict); # int32_t n_predict);
def llama_beam_search( def llama_beam_search(
ctx: llama_context_p, ctx: llama_context_p,
callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
@ -2340,8 +2348,8 @@ _lib.llama_beam_search.argtypes = [
llama_beam_search_callback_fn_t, llama_beam_search_callback_fn_t,
c_void_p, c_void_p,
c_size_t, c_size_t,
c_int, c_int32,
c_int, c_int32,
] ]
_lib.llama_beam_search.restype = None _lib.llama_beam_search.restype = None

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit f6793491b5af6da75edad34d6f503ef86d31b09f Subproject commit cb1e2818e0e12ec99f7236ec5d4f3ffd8bcc2f4a