diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index b4e3c75..dce1764 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,2 +1,2 @@ from .llama_cpp import * -from .llama import * \ No newline at end of file +from .llama import * diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 414a987..dc0f38b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -60,7 +60,11 @@ class Llama: stop = [s.encode("utf-8") for s in stop] prompt_tokens = llama_cpp.llama_tokenize( - self.ctx, prompt.encode("utf-8"), self.tokens, llama_cpp.llama_n_ctx(self.ctx), True + self.ctx, + prompt.encode("utf-8"), + self.tokens, + llama_cpp.llama_n_ctx(self.ctx), + True, ) if prompt_tokens + max_tokens > self.params.n_ctx: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 6ae8aa4..293dd0c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -67,6 +67,7 @@ def llama_context_default_params() -> llama_context_params: lib.llama_context_default_params.argtypes = [] lib.llama_context_default_params.restype = llama_context_params + # Various functions for loading a ggml llama model. # Allocate (almost) all memory needed for the model. # Return NULL on failure @@ -79,6 +80,7 @@ def llama_init_from_file( lib.llama_init_from_file.argtypes = [c_char_p, llama_context_params] lib.llama_init_from_file.restype = llama_context_p + # Frees all allocated memory def llama_free(ctx: llama_context_p): lib.llama_free(ctx) @@ -87,6 +89,7 @@ def llama_free(ctx: llama_context_p): lib.llama_free.argtypes = [llama_context_p] lib.llama_free.restype = None + # TODO: not great API - very likely to change # Returns 0 on success def llama_model_quantize( @@ -98,6 +101,7 @@ def llama_model_quantize( lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int] lib.llama_model_quantize.restype = c_int + # Run the llama inference to obtain the logits and probabilities for the next token. # tokens + n_tokens is the provided batch of new tokens to process # n_past is the number of tokens to use from previous eval calls @@ -155,6 +159,7 @@ def llama_n_ctx(ctx: llama_context_p) -> c_int: lib.llama_n_ctx.argtypes = [llama_context_p] lib.llama_n_ctx.restype = c_int + # Token logits obtained from the last call to llama_eval() # The logits for the last token are stored in the last row # Can be mutated in order to change the probabilities of the next token @@ -167,14 +172,17 @@ def llama_get_logits(ctx: llama_context_p): lib.llama_get_logits.argtypes = [llama_context_p] lib.llama_get_logits.restype = POINTER(c_float) + # Get the embeddings for the input # shape: [n_embd] (1-dimensional) def llama_get_embeddings(ctx: llama_context_p): return lib.llama_get_embeddings(ctx) + lib.llama_get_embeddings.argtypes = [llama_context_p] lib.llama_get_embeddings.restype = POINTER(c_float) + # Token Id -> String. Uses the vocabulary in the provided context def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes: return lib.llama_token_to_str(ctx, token) @@ -185,6 +193,7 @@ lib.llama_token_to_str.restype = c_char_p # Special tokens + def llama_token_bos() -> llama_token: return lib.llama_token_bos() @@ -230,6 +239,7 @@ lib.llama_sample_top_p_top_k.restype = llama_token # Performance information + def llama_print_timings(ctx: llama_context_p): lib.llama_print_timings(ctx) diff --git a/setup.py b/setup.py index b8b1b74..69b34a8 100644 --- a/setup.py +++ b/setup.py @@ -7,5 +7,5 @@ setup( author="Andrei Betlen", author_email="abetlen@gmail.com", license="MIT", - packages=["llama_cpp"] + packages=["llama_cpp"], )