Update llama.cpp

This commit is contained in:
Andrei Betlen 2023-09-09 12:12:32 -04:00
parent a7fb07ac77
commit d3f63211ef
2 changed files with 100 additions and 73 deletions

View file

@ -506,7 +506,7 @@ _lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
# LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
# LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx)
@ -524,6 +524,15 @@ _lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int
# LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
def llama_n_ctx_train(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx_train(ctx)
_lib.llama_n_ctx_train.argtypes = [llama_context_p]
_lib.llama_n_ctx_train.restype = c_int
# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
def llama_n_embd(ctx: llama_context_p) -> int:
return _lib.llama_n_embd(ctx)
@ -542,7 +551,7 @@ _lib.llama_vocab_type.argtypes = [llama_context_p]
_lib.llama_vocab_type.restype = c_int
# LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
# LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
def llama_model_n_vocab(model: llama_model_p) -> int:
return _lib.llama_model_n_vocab(model)
@ -560,6 +569,15 @@ _lib.llama_model_n_ctx.argtypes = [llama_model_p]
_lib.llama_model_n_ctx.restype = c_int
# LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
def llama_model_n_ctx_train(model: llama_model_p) -> int:
return _lib.llama_model_n_ctx_train(model)
_lib.llama_model_n_ctx_train.argtypes = [llama_model_p]
_lib.llama_model_n_ctx_train.restype = c_int
# LLAMA_API int llama_model_n_embd (const struct llama_model * model);
def llama_model_n_embd(model: llama_model_p) -> int:
return _lib.llama_model_n_embd(model)
@ -1046,74 +1064,14 @@ def llama_grammar_free(grammar: llama_grammar_p):
_lib.llama_grammar_free.argtypes = [llama_grammar_p]
_lib.llama_grammar_free.restype = None
# //
# // Beam search
# //
# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
return _lib.llama_grammar_copy(grammar)
# struct llama_beam_view {
# const llama_token * tokens;
# size_t n_tokens;
# float p; // Cumulative beam probability (renormalized relative to all beams)
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
# };
class llama_beam_view(ctypes.Structure):
_fields_ = [
("tokens", llama_token_p),
("n_tokens", c_size_t),
("p", c_float),
("eob", c_bool),
]
# // Passed to beam_search_callback function.
# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
# // These pointers are valid only during the synchronous callback, so should not be saved.
# struct llama_beams_state {
# struct llama_beam_view * beam_views;
# size_t n_beams; // Number of elements in beam_views[].
# size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
# bool last_call; // True iff this is the last callback invocation.
# };
class llama_beams_state(ctypes.Structure):
_fields_ = [
("beam_views", POINTER(llama_beam_view)),
("n_beams", c_size_t),
("common_prefix_length", c_size_t),
("last_call", c_bool),
]
# // Type of pointer to the beam_search_callback function.
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
# /// @details Deterministically returns entire sentence constructed by a beam search.
# /// @param ctx Pointer to the llama_context.
# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
# /// @param callback_data A pointer that is simply passed back to callback.
# /// @param n_beams Number of beams to use.
# /// @param n_past Number of tokens already evaluated.
# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
# /// @param n_threads Number of threads as passed to llama_eval().
# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
def llama_beam_search(
ctx: llama_context_p,
callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
callback_data: c_void_p,
n_beams: c_size_t,
n_past: c_int,
n_predict: c_int,
n_threads: c_int,
):
return _lib.llama_beam_search(
ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
)
_lib.llama_grammar_copy.argtypes = [llama_grammar_p]
_lib.llama_grammar_copy.restype = llama_grammar_p
# //
# // Sampling functions
@ -1436,6 +1394,74 @@ _lib.llama_grammar_accept_token.argtypes = [
llama_token,
]
_lib.llama_grammar_accept_token.restype = None
# //
# // Beam search
# //
# struct llama_beam_view {
# const llama_token * tokens;
# size_t n_tokens;
# float p; // Cumulative beam probability (renormalized relative to all beams)
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
# };
class llama_beam_view(ctypes.Structure):
_fields_ = [
("tokens", llama_token_p),
("n_tokens", c_size_t),
("p", c_float),
("eob", c_bool),
]
# // Passed to beam_search_callback function.
# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
# // These pointers are valid only during the synchronous callback, so should not be saved.
# struct llama_beams_state {
# struct llama_beam_view * beam_views;
# size_t n_beams; // Number of elements in beam_views[].
# size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
# bool last_call; // True iff this is the last callback invocation.
# };
class llama_beams_state(ctypes.Structure):
_fields_ = [
("beam_views", POINTER(llama_beam_view)),
("n_beams", c_size_t),
("common_prefix_length", c_size_t),
("last_call", c_bool),
]
# // Type of pointer to the beam_search_callback function.
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
# /// @details Deterministically returns entire sentence constructed by a beam search.
# /// @param ctx Pointer to the llama_context.
# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
# /// @param callback_data A pointer that is simply passed back to callback.
# /// @param n_beams Number of beams to use.
# /// @param n_past Number of tokens already evaluated.
# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
# /// @param n_threads Number of threads as passed to llama_eval().
# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
def llama_beam_search(
ctx: llama_context_p,
callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
callback_data: c_void_p,
n_beams: c_size_t,
n_past: c_int,
n_predict: c_int,
n_threads: c_int,
):
return _lib.llama_beam_search(
ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
)
# Performance information
@ -1494,6 +1520,7 @@ _lib.llama_log_set.restype = None
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p):
return _lib.llama_dump_timing_info_yaml(stream, ctx)
_lib.llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p]
_lib.llama_dump_timing_info_yaml.restype = None

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 69fdbb9abc8907dd2a9ffdd840cba92d678a660a
Subproject commit 21ac3a1503001020122db5dce6adf34b761675f5