Merge branch 'main' into v0.2-wip
commit 1910793f56
11 changed files with 253 additions and 146 deletions

README.md (13 changes)
@@ -21,7 +21,7 @@ Documentation is available at [https://llama-cpp-python.readthedocs.io/en/latest

> Starting with version 0.1.79 the model format has changed from `ggmlv3` to `gguf`. Old model files can be converted using the `convert-llama-ggmlv3-to-gguf.py` script in [`llama.cpp`](https://github.com/ggerganov/llama.cpp)

-## Installation from PyPI (recommended)
+## Installation from PyPI

Install from PyPI (requires a c compiler):
@@ -45,7 +45,7 @@ bash Miniforge3-MacOSX-arm64.sh
```

Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac.

-### Installation with OpenBLAS / cuBLAS / CLBlast / Metal
+### Installation with Hardware Acceleration

`llama.cpp` supports multiple BLAS backends for faster processing.
Use the `FORCE_CMAKE=1` environment variable to force the use of `cmake` and install the pip package for the desired BLAS backend.
@@ -74,6 +74,12 @@ To install with Metal (MPS), set the `LLAMA_METAL=on` environment variable befor
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python
```

+To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on` environment variable before installing:
+
+```bash
+CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python
+```

#### Windows remarks

To set the variables `CMAKE_ARGS` and `FORCE_CMAKE` in PowerShell, follow the next steps (Example using OpenBLAS):
@@ -181,7 +187,8 @@ Below is a short example demonstrating how to use the low-level API to tokenize
>>> import ctypes
>>> params = llama_cpp.llama_context_default_params()
# use bytes for char * params
->>> ctx = llama_cpp.llama_init_from_file(b"./models/7b/ggml-model.bin", params)
+>>> model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
+>>> ctx = llama_cpp.llama_new_context_with_model(model, params)
>>> max_tokens = params.n_ctx
# use ctypes arrays for array params
>>> tokens = (llama_cpp.llama_token * int(max_tokens))()
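Editor's note: the README change above replaces the one-step `llama_init_from_file` with a two-step load (model, then context). A minimal sketch of the new sequence, assuming a local model at the path shown in the README and an arbitrary token id purely for illustration:

```python
import ctypes
import llama_cpp

params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

# Round-trip a single (illustrative) token id through the renamed *_to_piece_with_model binding.
buf_size = 32
buf = (ctypes.c_char * buf_size)()
n = llama_cpp.llama_token_to_piece_with_model(model, llama_cpp.llama_token(1), buf, buf_size)
print(buf[:n].decode("utf-8", errors="ignore"))

llama_cpp.llama_free(ctx)
```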
@@ -24,6 +24,10 @@ class LLaMAInteract:
    def __init__(self, params: GptParams) -> None:
        # input args
        self.params = params
+        if self.params.path_session is None:
+            self.params.path_session = ""
+        if self.params.antiprompt is None:
+            self.params.antiprompt = ""

        if (self.params.perplexity):
            raise NotImplementedError("""************
@@ -66,7 +70,9 @@ specified) expect poor results""", file=sys.stderr)
        self.lparams.use_mlock = self.params.use_mlock
        self.lparams.use_mmap = self.params.use_mmap

-        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.params.model.encode("utf8"), self.lparams)
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
        if (not self.ctx):
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")
@@ -181,12 +187,12 @@ prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

        for i in range(len(self.embd_inp)):
-            print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+            print(f"{self.embd_inp[i]} -> '{self.token_to_str(self.embd_inp[i])}'", file=sys.stderr)

        if (self.params.n_keep > 0):
            print("static prompt based on n_keep: '")
            for i in range(self.params.n_keep):
-                print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                print(self.token_to_str(self.embd_inp[i]), file=sys.stderr)
            print("'", file=sys.stderr)
            print(file=sys.stderr)
@@ -339,7 +345,7 @@ n_keep = {self.params.n_keep}
            candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))

            # Apply penalties
-            nl_logit = logits[llama_cpp.llama_token_nl()]
+            nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
            last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

            _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
@@ -380,7 +386,7 @@ n_keep = {self.params.n_keep}
                self.last_n_tokens.append(id)

                # replace end of text token with newline token when in interactive mode
-                if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
+                if (id == llama_cpp.llama_token_eos(self.ctx) and self.params.interactive and not self.params.instruct):
                    id = self.llama_token_newline[0]
                    self.embd.append(id)
                    if (self.use_antiprompt()):
@@ -437,7 +443,7 @@ n_keep = {self.params.n_keep}
                    break

            # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(self.ctx):
                if (not self.params.instruct):
                    for i in self.llama_token_eot:
                        yield i
@@ -464,10 +470,18 @@ n_keep = {self.params.n_keep}
        llama_cpp.llama_free(self.ctx)
        self.set_color(util.CONSOLE_COLOR_DEFAULT)

+    def token_to_str(self, token_id: int) -> bytes:
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece_with_model(
+            self.model, llama_cpp.llama_token(token_id), buffer, size)
+        assert n <= size
+        return bytes(buffer[:n])
+
    # return past text
    def past(self):
        for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
+            yield self.token_to_str(id).decode("utf8", errors="ignore")

    # write input
    def input(self, prompt: str):
@@ -481,7 +495,7 @@ n_keep = {self.params.n_keep}
    def output(self):
        self.remaining_tokens = self.params.n_predict
        for id in self.generate():
-            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
+            cur_char = self.token_to_str(id)

            # Add remainder of missing bytes
            if None in self.multibyte_fix:
@@ -1,15 +1,17 @@
-import llama_cpp
+import ctypes
import os
import multiprocessing

+import llama_cpp

N_THREADS = multiprocessing.cpu_count()
MODEL_PATH = os.environ.get('MODEL', "../models/7B/ggml-model.bin")

prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"

lparams = llama_cpp.llama_context_default_params()
-ctx = llama_cpp.llama_init_from_file(b"../models/7B/ggml-model.bin", lparams)
+model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode('utf-8'), lparams)
+ctx = llama_cpp.llama_new_context_with_model(model, lparams)

# determine the required inference memory per token:
tmp = [0, 1, 2, 3]
@@ -58,7 +60,8 @@ while remaining_tokens > 0:
            llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
            for token_id in range(n_vocab)
        ])
-    candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
+    candidates_p = llama_cpp.ctypes.pointer(
+        llama_cpp.llama_token_data_array(_arr, len(_arr), False))

    _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
    llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
@@ -68,9 +71,9 @@ while remaining_tokens > 0:
        _arr,
        last_n_repeat, frequency_penalty, presence_penalty)

-    llama_cpp.llama_sample_top_k(ctx, candidates_p, 40)
-    llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.8)
-    llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.2)
+    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
+    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
+    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=0.2)
    id = llama_cpp.llama_sample_token(ctx, candidates_p)

    last_n_tokens_data = last_n_tokens_data[1:] + [id]
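Editor's note: the hunk above only switches the sampling calls to keyword arguments; the surrounding pipeline is unchanged. For orientation, a condensed sketch of that pipeline as a helper, assuming a context `ctx` loaded as in the earlier example (the cutoff and temperature values are illustrative):

```python
import llama_cpp

def sample_next_token(ctx):
    # Build the candidates array from the current logits.
    n_vocab = llama_cpp.llama_n_vocab(ctx)
    logits = llama_cpp.llama_get_logits(ctx)
    _arr = (llama_cpp.llama_token_data * n_vocab)(*[
        llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
        for token_id in range(n_vocab)
    ])
    candidates_p = llama_cpp.ctypes.pointer(
        llama_cpp.llama_token_data_array(_arr, len(_arr), False))

    # Narrow the distribution, then draw a token.
    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=0.2)
    return llama_cpp.llama_sample_token(ctx, candidates_p)
```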
@@ -86,13 +89,18 @@ while remaining_tokens > 0:
            break
    if not input_noecho:
        for id in embd:
+            size = 32
+            buffer = (ctypes.c_char * size)()
+            n = llama_cpp.llama_token_to_piece_with_model(
+                model, llama_cpp.llama_token(id), buffer, size)
+            assert n <= size
            print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
+                buffer[:n].decode('utf-8'),
                end="",
                flush=True,
            )

-    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos():
+    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos(ctx):
        break

print()
@@ -1,2 +1,4 @@
from .llama_cpp import *
from .llama import *
+
+from .version import __version__
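Editor's note: with `__version__` re-exported from the package root (see `llama_cpp/version.py` below), the installed version can be checked directly; a trivial sketch:

```python
import llama_cpp

print(llama_cpp.__version__)  # e.g. "0.1.85"
```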
@@ -452,10 +452,10 @@ class Llama:
        """
        assert self.model is not None
        output = b""
-        size = 8
+        size = 32
        buffer = (ctypes.c_char * size)()
        for token in tokens:
-            n = llama_cpp.llama_token_to_str_with_model(
+            n = llama_cpp.llama_token_to_piece_with_model(
                self.model, llama_cpp.llama_token(token), buffer, size
            )
            assert n <= size
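Editor's note: at the high-level API layer, `detokenize` above is the inverse of `tokenize`. A quick hedged round-trip sketch, assuming an illustrative local GGUF model path:

```python
import llama_cpp

llm = llama_cpp.Llama(model_path="./models/7b/ggml-model.gguf", verbose=False)

tokens = llm.tokenize(b"Hello, world!")
print(tokens)
print(llm.detokenize(tokens))  # may carry a leading space for SentencePiece-style vocabularies
```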
@@ -1007,13 +1007,15 @@ class Llama:
break

token_end_position = 0
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token
if token_end_position >= (remaining_length - first_stop_position):
break
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:

if logprobs is not None:
# not sure how to handle this branch when dealing
# with CJK output, so keep it unchanged
for token in remaining_tokens:
token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token
if token_end_position > (remaining_length - first_stop_position):
break
token_str = self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
@@ -1046,23 +1048,59 @@ class Llama:
"token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob],
}
-returned_tokens += 1
-yield {
-"id": completion_id,
-"object": "text_completion",
-"created": created,
-"model": model_name,
-"choices": [
-{
-"text": self.detokenize([token]).decode(
-"utf-8", errors="ignore"
-),
-"index": 0,
-"logprobs": logprobs_or_none,
-"finish_reason": None,
-}
-],
-}
+returned_tokens += 1
+yield {
+"id": completion_id,
+"object": "text_completion",
+"created": created,
+"model": model_name,
+"choices": [
+{
+"text": self.detokenize([token]).decode(
+"utf-8", errors="ignore"
+),
+"index": 0,
+"logprobs": logprobs_or_none,
+"finish_reason": None,
+}
+],
+}
+else:
+while len(remaining_tokens) > 0:
+decode_success = False
+for i in range(1, len(remaining_tokens) + 1):
+try:
+bs = self.detokenize(remaining_tokens[:i])
+ts = bs.decode('utf-8')
+decode_success = True
+break
+except UnicodeError:
+pass
+else:
+break
+if not decode_success:
+# all remaining tokens cannot be decoded to a UTF-8 character
+break
+token_end_position += len(bs)
+if token_end_position > (remaining_length - first_stop_position):
+break
+remaining_tokens = remaining_tokens[i:]
+returned_tokens += i

+yield {
+"id": completion_id,
+"object": "text_completion",
+"created": created,
+"model": model_name,
+"choices": [
+{
+"text": ts,
+"index": 0,
+"logprobs": None,
+"finish_reason": None,
+}
+],
+}

if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)
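Editor's note: the new `else` branch above streams text only once the accumulated token bytes decode as valid UTF-8, which is what makes CJK and emoji output safe to stream. A standalone sketch of that incremental-decode idea (the byte chunks here are illustrative, not taken from the library):

```python
# Incremental UTF-8 decoding: only emit text once the buffered bytes decode cleanly.
chunks = [b"\xe4\xb8", b"\xad", b"\xe6\x96\x87", b"!"]  # "中文!" split mid-character

buffered = b""
for chunk in chunks:
    buffered += chunk
    try:
        text = buffered.decode("utf-8")
    except UnicodeDecodeError:
        continue  # wait for the rest of the multi-byte character
    print(text, end="", flush=True)
    buffered = b""
print()
```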
@@ -294,6 +294,7 @@ llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p)
# enum llama_ftype ftype; // quantize to this llama_ftype
# bool allow_requantize; // allow quantizing non-f32/f16 tensors
# bool quantize_output_tensor; // quantize output.weight
+# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
# } llama_model_quantize_params;
class llama_model_quantize_params(Structure):
    _fields_ = [

@@ -301,6 +302,7 @@ class llama_model_quantize_params(Structure):
        ("ftype", c_int),
        ("allow_requantize", c_bool),
        ("quantize_output_tensor", c_bool),
+        ("only_copy", c_bool),
    ]
@@ -504,7 +506,7 @@ _lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool


-# LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
+# LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int:
    return _lib.llama_n_vocab(ctx)


@@ -513,7 +515,7 @@ _lib.llama_n_vocab.argtypes = [llama_context_p]
_lib.llama_n_vocab.restype = c_int


-# LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
+# LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
def llama_n_ctx(ctx: llama_context_p) -> int:
    return _lib.llama_n_ctx(ctx)
@@ -522,7 +524,16 @@ _lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int


-# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
+# LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
+def llama_n_ctx_train(ctx: llama_context_p) -> int:
+    return _lib.llama_n_ctx_train(ctx)
+
+
+_lib.llama_n_ctx_train.argtypes = [llama_context_p]
+_lib.llama_n_ctx_train.restype = c_int
+
+
+# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
def llama_n_embd(ctx: llama_context_p) -> int:
    return _lib.llama_n_embd(ctx)
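Editor's note: the new `llama_n_ctx_train` binding exposes the context length the model was trained with. A small sketch of using it to sanity-check a requested context size, assuming an illustrative model path and the loading sequence from this commit:

```python
import llama_cpp

params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.gguf", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

n_ctx_train = llama_cpp.llama_n_ctx_train(ctx)
n_ctx = llama_cpp.llama_n_ctx(ctx)
if n_ctx > n_ctx_train:
    print(f"warning: requested n_ctx {n_ctx} exceeds training context {n_ctx_train}")

llama_cpp.llama_free(ctx)
```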
@@ -540,7 +551,7 @@ _lib.llama_vocab_type.argtypes = [llama_context_p]
_lib.llama_vocab_type.restype = c_int


-# LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
+# LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
def llama_model_n_vocab(model: llama_model_p) -> int:
    return _lib.llama_model_n_vocab(model)


@@ -549,7 +560,7 @@ _lib.llama_model_n_vocab.argtypes = [llama_model_p]
_lib.llama_model_n_vocab.restype = c_int


-# LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
+# LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
def llama_model_n_ctx(model: llama_model_p) -> int:
    return _lib.llama_model_n_ctx(model)
@@ -558,7 +569,16 @@ _lib.llama_model_n_ctx.argtypes = [llama_model_p]
_lib.llama_model_n_ctx.restype = c_int


-# LLAMA_API int llama_model_n_embd (const struct llama_model * model);
+# LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
+def llama_model_n_ctx_train(model: llama_model_p) -> int:
+    return _lib.llama_model_n_ctx_train(model)
+
+
+_lib.llama_model_n_ctx_train.argtypes = [llama_model_p]
+_lib.llama_model_n_ctx_train.restype = c_int
+
+
+# LLAMA_API int llama_model_n_embd (const struct llama_model * model);
def llama_model_n_embd(model: llama_model_p) -> int:
    return _lib.llama_model_n_embd(model)
@@ -973,48 +993,43 @@ _lib.llama_tokenize_with_model.argtypes = [
_lib.llama_tokenize_with_model.restype = c_int


# // Token Id -> String. Uses the vocabulary in the provided context
# // Does not write null terminator to the buffer
# LLAMA_API int llama_token_to_str(
# // Token Id -> Piece.
# // Uses the vocabulary in the provided context.
# // Does not write null terminator to the buffer.
# // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
# LLAMA_API int llama_token_to_piece(
# const struct llama_context * ctx,
# llama_token token,
# char * buf,
# int length);
def llama_token_to_str(
# llama_token token,
# char * buf,
# int length);
def llama_token_to_piece(
ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int
) -> int:
return _lib.llama_token_to_str(ctx, token, buf, length)
return _lib.llama_token_to_piece(ctx, token, buf, length)


_lib.llama_tokenize_with_model.argtypes = [
llama_model_p,
c_char_p,
llama_token_p,
c_int,
c_bool,
]
_lib.llama_tokenize_with_model.restype = c_int
_lib.llama_token_to_piece.argtypes = [llama_context_p, llama_token, c_char_p, c_int]
_lib.llama_token_to_piece.restype = c_int


# LLAMA_API int llama_token_to_str_with_model(
#    const struct llama_model * model,
#    llama_token token,
#    char * buf,
#    int length);
def llama_token_to_str_with_model(
# LLAMA_API int llama_token_to_piece_with_model(
#    const struct llama_model * model,
#    llama_token token,
#    char * buf,
#    int length);
def llama_token_to_piece_with_model(
model: llama_model_p, token: llama_token, buf: bytes, length: c_int
) -> int:
return _lib.llama_token_to_str_with_model(model, token, buf, length)
return _lib.llama_token_to_piece_with_model(model, token, buf, length)


_lib.llama_token_to_str_with_model.argtypes = [
_lib.llama_token_to_piece_with_model.argtypes = [
llama_model_p,
llama_token,
c_char_p,
c_int,
]
_lib.llama_token_to_str_with_model.restype = c_int

_lib.llama_token_to_piece_with_model.restype = c_int

# //
# // Grammar
@@ -1049,74 +1064,14 @@ def llama_grammar_free(grammar: llama_grammar_p):
_lib.llama_grammar_free.argtypes = [llama_grammar_p]
_lib.llama_grammar_free.restype = None

-# //
-# // Beam search
-# //

+# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
+def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
+    return _lib.llama_grammar_copy(grammar)

-# struct llama_beam_view {
-#     const llama_token * tokens;
-#     size_t n_tokens;
-#     float p; // Cumulative beam probability (renormalized relative to all beams)
-#     bool eob; // Callback should set this to true when a beam is at end-of-beam.
-# };
-class llama_beam_view(ctypes.Structure):
-    _fields_ = [
-        ("tokens", llama_token_p),
-        ("n_tokens", c_size_t),
-        ("p", c_float),
-        ("eob", c_bool),
-    ]

-# // Passed to beam_search_callback function.
-# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
-# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
-# // These pointers are valid only during the synchronous callback, so should not be saved.
-# struct llama_beams_state {
-#     struct llama_beam_view * beam_views;
-#     size_t n_beams; // Number of elements in beam_views[].
-#     size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
-#     bool last_call; // True iff this is the last callback invocation.
-# };
-class llama_beams_state(ctypes.Structure):
-    _fields_ = [
-        ("beam_views", POINTER(llama_beam_view)),
-        ("n_beams", c_size_t),
-        ("common_prefix_length", c_size_t),
-        ("last_call", c_bool),
-    ]

-# // Type of pointer to the beam_search_callback function.
-# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
-# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
-# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
-llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)

-# /// @details Deterministically returns entire sentence constructed by a beam search.
-# /// @param ctx Pointer to the llama_context.
-# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
-# /// @param callback_data A pointer that is simply passed back to callback.
-# /// @param n_beams Number of beams to use.
-# /// @param n_past Number of tokens already evaluated.
-# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
-# /// @param n_threads Number of threads as passed to llama_eval().
-# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
-def llama_beam_search(
-    ctx: llama_context_p,
-    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
-    callback_data: c_void_p,
-    n_beams: c_size_t,
-    n_past: c_int,
-    n_predict: c_int,
-    n_threads: c_int,
-):
-    return _lib.llama_beam_search(
-        ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
-    )

+_lib.llama_grammar_copy.argtypes = [llama_grammar_p]
+_lib.llama_grammar_copy.restype = llama_grammar_p

# //
# // Sampling functions
@@ -1439,6 +1394,74 @@ _lib.llama_grammar_accept_token.argtypes = [
    llama_token,
]
_lib.llama_grammar_accept_token.restype = None

+# //
+# // Beam search
+# //

+# struct llama_beam_view {
+#     const llama_token * tokens;
+#     size_t n_tokens;
+#     float p; // Cumulative beam probability (renormalized relative to all beams)
+#     bool eob; // Callback should set this to true when a beam is at end-of-beam.
+# };
+class llama_beam_view(ctypes.Structure):
+    _fields_ = [
+        ("tokens", llama_token_p),
+        ("n_tokens", c_size_t),
+        ("p", c_float),
+        ("eob", c_bool),
+    ]

+# // Passed to beam_search_callback function.
+# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+# // These pointers are valid only during the synchronous callback, so should not be saved.
+# struct llama_beams_state {
+#     struct llama_beam_view * beam_views;
+#     size_t n_beams; // Number of elements in beam_views[].
+#     size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
+#     bool last_call; // True iff this is the last callback invocation.
+# };
+class llama_beams_state(ctypes.Structure):
+    _fields_ = [
+        ("beam_views", POINTER(llama_beam_view)),
+        ("n_beams", c_size_t),
+        ("common_prefix_length", c_size_t),
+        ("last_call", c_bool),
+    ]

+# // Type of pointer to the beam_search_callback function.
+# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
+llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)

+# /// @details Deterministically returns entire sentence constructed by a beam search.
+# /// @param ctx Pointer to the llama_context.
+# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+# /// @param callback_data A pointer that is simply passed back to callback.
+# /// @param n_beams Number of beams to use.
+# /// @param n_past Number of tokens already evaluated.
+# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+# /// @param n_threads Number of threads as passed to llama_eval().
+# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+def llama_beam_search(
+    ctx: llama_context_p,
+    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
+    callback_data: c_void_p,
+    n_beams: c_size_t,
+    n_past: c_int,
+    n_predict: c_int,
+    n_threads: c_int,
+):
+    return _lib.llama_beam_search(
+        ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
+    )


# Performance information
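Editor's note: because the beam-search bindings above are plain ctypes, a Python callback has to be wrapped in `llama_beam_search_callback_fn_t` before it can be handed to `llama_beam_search`. A minimal sketch of such a wrapper (the model/context setup is omitted and the callback body is illustrative, not part of this commit):

```python
import llama_cpp

def my_beam_search_callback(callback_data, beams_state):
    # beams_state is passed by value; inspect each beam's cumulative probability.
    for i in range(beams_state.n_beams):
        view = beams_state.beam_views[i]
        print(f"beam {i}: n_tokens={view.n_tokens} p={view.p:.4f} eob={view.eob}")
    if beams_state.last_call:
        print("beam search finished")

# Wrap the Python function so it matches the C callback signature.
callback = llama_cpp.llama_beam_search_callback_fn_t(my_beam_search_callback)

# With a loaded context `ctx`, the call would look roughly like (not run here):
# llama_cpp.llama_beam_search(ctx, callback, None, n_beams=4, n_past=0, n_predict=128, n_threads=4)
```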
@@ -1492,6 +1515,16 @@ def llama_log_set(
_lib.llama_log_set.argtypes = [llama_log_callback, c_void_p]
_lib.llama_log_set.restype = None


+# LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
+def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p):
+    return _lib.llama_dump_timing_info_yaml(stream, ctx)
+
+
+_lib.llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p]
+_lib.llama_dump_timing_info_yaml.restype = None


###################################################################################################
@@ -579,6 +579,7 @@ def make_logit_bias_processor(
@router.post(
    "/v1/completions",
)
+@router.post("/v1/engines/copilot-codex/completions")
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
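Editor's note: both routes above resolve to the same completion handler, so a client can exercise the new Copilot-style path exactly like the standard one. A hedged sketch using `requests`; the host, port, and payload values are illustrative, and the server is assumed to be running via `python -m llama_cpp.server`:

```python
import requests

payload = {"prompt": "Q: Name the planets in the solar system? A: ", "max_tokens": 64}

for path in ("/v1/completions", "/v1/engines/copilot-codex/completions"):
    r = requests.post(f"http://localhost:8000{path}", json=payload)
    r.raise_for_status()
    print(path, "->", r.json()["choices"][0]["text"][:40])
```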
llama_cpp/version.py (new file, 1 change)

@@ -0,0 +1 @@
+__version__ = "0.1.85"
@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "llama_cpp_python"
-version = "0.1.79"
+version = "0.1.85"
description = "Python bindings for the llama.cpp library"
readme = "README.md"
license = { text = "MIT" }
@@ -181,3 +181,6 @@ def test_llama_server():
            }
        ],
    }
+
+
+def test_llama_cpp_version():
+    assert llama_cpp.__version__
vendor/llama.cpp (vendored, 2 changes)

@@ -1 +1 @@
-Subproject commit 232caf3c1581a6cb023571780ff41dc2d66d1ca0
+Subproject commit 89e89599fd095172f8d67903b5e227467420f036