This commit is contained in:
commit
dc23d15918
10 changed files with 232 additions and 43 deletions
18
CHANGELOG.md
18
CHANGELOG.md
|
@ -7,6 +7,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.2.56]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@c2101a2e909ac7c08976d414e64e96c90ee5fa9e
|
||||
- feat(server): Add endpoints for tokenize, detokenize and count tokens by @felipelo in #1136
|
||||
- feat: Switch embed to llama_get_embeddings_seq by @iamlemec in #1263
|
||||
- fix: Fixed json strings grammar by blacklisting character control set by @ExtReMLapin in d02a9cf16ff88ad011e2eb1ce29f4d9400f13cd1
|
||||
- fix: Check for existence of clip model path by @kejcao in #1264
|
||||
|
||||
## [0.2.55]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/9731134296af3a6839cd682e51d9c2109a871de5
|
||||
- docs: fix small typo in README: 'model know how' -> 'model knows how' by @boegel in #1244
|
||||
|
||||
## [0.2.54]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@cb49e0f8c906e5da49e9f6d64a57742a9a241c6a
|
||||
- docs: fix typo in README.md embeddings example by @iamlemec in #1232
|
||||
|
||||
## [0.2.53]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@cb49e0f8c906e5da49e9f6d64a57742a9a241c6a
|
||||
|
|
13
README.md
13
README.md
|
@ -286,7 +286,16 @@ By default [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest
|
|||
|
||||
The high-level API also provides a simple interface for chat completion.
|
||||
|
||||
Note that `chat_format` option must be set for the particular model you are using.
|
||||
Chat completion requires that the model knows how to format the messages into a single prompt.
|
||||
The `Llama` class does this using pre-registered chat formats (ie. `chatml`, `llama-2`, `gemma`, etc) or by providing a custom chat handler object.
|
||||
|
||||
The model will will format the messages into a single prompt using the following order of precedence:
|
||||
- Use the `chat_handler` if provided
|
||||
- Use the `chat_format` if provided
|
||||
- Use the `tokenizer.chat_template` from the `gguf` model's metadata (should work for most new models, older models may not have this)
|
||||
- else, fallback to the `llama-2` chat format
|
||||
|
||||
Set `verbose=True` to see the selected chat format.
|
||||
|
||||
```python
|
||||
>>> from llama_cpp import Llama
|
||||
|
@ -525,7 +534,7 @@ To generate text embeddings use [`create_embedding`](http://localhost:8000/api-r
|
|||
```python
|
||||
import llama_cpp
|
||||
|
||||
llm = llama_cpp.Llama(model_path="path/to/model.gguf", embeddings=True)
|
||||
llm = llama_cpp.Llama(model_path="path/to/model.gguf", embedding=True)
|
||||
|
||||
embeddings = llm.create_embedding("Hello, world!")
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .llama_cpp import *
|
||||
from .llama import *
|
||||
|
||||
__version__ = "0.2.53"
|
||||
__version__ = "0.2.56"
|
|
@ -86,7 +86,6 @@ class Llama:
|
|||
yarn_beta_fast: float = 32.0,
|
||||
yarn_beta_slow: float = 1.0,
|
||||
yarn_orig_ctx: int = 0,
|
||||
mul_mat_q: bool = True,
|
||||
logits_all: bool = False,
|
||||
embedding: bool = False,
|
||||
offload_kqv: bool = True,
|
||||
|
@ -291,11 +290,10 @@ class Llama:
|
|||
yarn_beta_slow if yarn_beta_slow != 0.0 else 0
|
||||
)
|
||||
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
|
||||
self.context_params.mul_mat_q = mul_mat_q
|
||||
self.context_params.logits_all = (
|
||||
logits_all if draft_model is None else True
|
||||
) # Must be set to True for speculative decoding
|
||||
self.context_params.embedding = embedding
|
||||
self.context_params.embeddings = embedding # TODO: Rename to embeddings
|
||||
self.context_params.offload_kqv = offload_kqv
|
||||
|
||||
# Sampling Params
|
||||
|
@ -412,7 +410,7 @@ class Llama:
|
|||
bos_token = self._model.token_get_text(bos_token_id)
|
||||
|
||||
if self.verbose:
|
||||
print(f"Using chat template: {template}", file=sys.stderr)
|
||||
print(f"Using gguf chat template: {template}", file=sys.stderr)
|
||||
print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
|
||||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||
|
||||
|
@ -422,6 +420,8 @@ class Llama:
|
|||
|
||||
if self.chat_format is None and self.chat_handler is None:
|
||||
self.chat_format = "llama-2"
|
||||
if self.verbose:
|
||||
print(f"Using fallback chat format: {chat_format}", file=sys.stderr)
|
||||
|
||||
@property
|
||||
def ctx(self) -> llama_cpp.llama_context_p:
|
||||
|
@ -787,7 +787,7 @@ class Llama:
|
|||
n_embd = self.n_embd()
|
||||
n_batch = self.n_batch
|
||||
|
||||
if self.context_params.embedding == False:
|
||||
if self.context_params.embeddings == False:
|
||||
raise RuntimeError(
|
||||
"Llama model must be created with embedding=True to call this method"
|
||||
)
|
||||
|
@ -814,7 +814,7 @@ class Llama:
|
|||
|
||||
# store embeddings
|
||||
for i in range(n_seq):
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
|
||||
self._ctx.ctx, i
|
||||
)[:n_embd]
|
||||
if normalize:
|
||||
|
@ -1724,9 +1724,8 @@ class Llama:
|
|||
yarn_beta_fast=self.context_params.yarn_beta_fast,
|
||||
yarn_beta_slow=self.context_params.yarn_beta_slow,
|
||||
yarn_orig_ctx=self.context_params.yarn_orig_ctx,
|
||||
mul_mat_q=self.context_params.mul_mat_q,
|
||||
logits_all=self.context_params.logits_all,
|
||||
embedding=self.context_params.embedding,
|
||||
embedding=self.context_params.embeddings,
|
||||
# Sampling Params
|
||||
last_n_tokens_size=self.last_n_tokens_size,
|
||||
# LoRA Params
|
||||
|
@ -1768,7 +1767,6 @@ class Llama:
|
|||
yarn_beta_fast=state["yarn_beta_fast"],
|
||||
yarn_beta_slow=state["yarn_beta_slow"],
|
||||
yarn_orig_ctx=state["yarn_orig_ctx"],
|
||||
mul_mat_q=state["mul_mat_q"],
|
||||
logits_all=state["logits_all"],
|
||||
embedding=state["embedding"],
|
||||
# Sampling Params
|
||||
|
|
|
@ -1848,6 +1848,9 @@ class Llava15ChatHandler:
|
|||
self.verbose = verbose
|
||||
self._clip_free = self._llava_cpp._libllava.clip_free # type: ignore
|
||||
|
||||
if not os.path.exists(clip_model_path):
|
||||
raise ValueError(f"Clip model path does not exist: {clip_model_path}")
|
||||
|
||||
with suppress_stdout_stderr(disable=self.verbose):
|
||||
self.clip_ctx = self._llava_cpp.clip_model_load(
|
||||
self.clip_model_path.encode(), 0
|
||||
|
|
|
@ -148,6 +148,12 @@ ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
|
|||
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
|
||||
)
|
||||
|
||||
# // Abort callback
|
||||
# // If not NULL, called before ggml computation
|
||||
# // If it returns true, the computation is aborted
|
||||
# typedef bool (*ggml_abort_callback)(void * data);
|
||||
ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p)
|
||||
|
||||
# llama.h bindings
|
||||
|
||||
_lib.llama_max_devices.argtypes = []
|
||||
|
@ -314,10 +320,12 @@ LLAMA_ROPE_SCALING_TYPE_YARN = 2
|
|||
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
|
||||
|
||||
# enum llama_pooling_type {
|
||||
# LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
|
||||
# LLAMA_POOLING_TYPE_NONE = 0,
|
||||
# LLAMA_POOLING_TYPE_MEAN = 1,
|
||||
# LLAMA_POOLING_TYPE_CLS = 2,
|
||||
# };
|
||||
LLAMA_POOLING_TYPE_UNSPECIFIED = -1
|
||||
LLAMA_POOLING_TYPE_NONE = 0
|
||||
LLAMA_POOLING_TYPE_MEAN = 1
|
||||
LLAMA_POOLING_TYPE_CLS = 2
|
||||
|
@ -391,7 +399,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(
|
|||
# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
# // - pos : the positions of the respective token in the sequence
|
||||
# // - seq_id : the sequence to which the respective token belongs
|
||||
# // - logits : if zero, the logits for the respective token will not be output
|
||||
# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
# //
|
||||
# typedef struct llama_batch {
|
||||
# int32_t n_tokens;
|
||||
|
@ -401,7 +409,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(
|
|||
# llama_pos * pos;
|
||||
# int32_t * n_seq_id;
|
||||
# llama_seq_id ** seq_id;
|
||||
# int8_t * logits;
|
||||
# int8_t * logits; // TODO: rename this to "output"
|
||||
|
||||
|
||||
# // NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||
|
@ -421,10 +429,12 @@ class llama_batch(ctypes.Structure):
|
|||
The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
||||
|
||||
Attributes:
|
||||
n_tokens (int): number of tokens
|
||||
token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
|
||||
embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
|
||||
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
|
||||
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
|
||||
"""
|
||||
|
||||
_fields_ = [
|
||||
|
@ -539,9 +549,13 @@ class llama_model_params(ctypes.Structure):
|
|||
# uint32_t seed; // RNG seed, -1 for random
|
||||
# uint32_t n_ctx; // text context, 0 = from model
|
||||
# uint32_t n_batch; // prompt processing maximum batch size
|
||||
# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
|
||||
# uint32_t n_threads; // number of threads to use for generation
|
||||
# uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
# int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
||||
# enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
# enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
||||
# // (ignored if no pooling layer)
|
||||
|
||||
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
# float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
|
@ -559,13 +573,16 @@ class llama_model_params(ctypes.Structure):
|
|||
# enum ggml_type type_k; // data type for K cache
|
||||
# enum ggml_type type_v; // data type for V cache
|
||||
|
||||
|
||||
# // Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
# bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||
# bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
# bool embedding; // embedding mode only
|
||||
# bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
# bool embeddings; // if true, extract embeddings (together with logits)
|
||||
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
# bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
|
||||
|
||||
# // Abort callback
|
||||
# // if it returns true, execution of llama_decode() will be aborted
|
||||
# // currently works only with CPU execution
|
||||
# ggml_abort_callback abort_callback;
|
||||
# void * abort_callback_data;
|
||||
# };
|
||||
class llama_context_params(ctypes.Structure):
|
||||
"""Parameters for llama_context
|
||||
|
@ -574,9 +591,11 @@ class llama_context_params(ctypes.Structure):
|
|||
seed (int): RNG seed, -1 for random
|
||||
n_ctx (int): text context, 0 = from model
|
||||
n_batch (int): prompt processing maximum batch size
|
||||
n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
|
||||
n_threads (int): number of threads to use for generation
|
||||
n_threads_batch (int): number of threads to use for batch processing
|
||||
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
|
||||
rope_freq_base (float): RoPE base frequency, 0 = from model
|
||||
rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
|
||||
yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
|
||||
|
@ -589,20 +608,22 @@ class llama_context_params(ctypes.Structure):
|
|||
cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval
|
||||
type_k (int): data type for K cache
|
||||
type_v (int): data type for V cache
|
||||
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||
logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
embedding (bool): embedding mode only
|
||||
embeddings (bool): if true, extract embeddings (together with logits)
|
||||
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
|
||||
do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
|
||||
abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
|
||||
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
|
||||
"""
|
||||
|
||||
_fields_ = [
|
||||
("seed", ctypes.c_uint32),
|
||||
("n_ctx", ctypes.c_uint32),
|
||||
("n_batch", ctypes.c_uint32),
|
||||
("n_parallel", ctypes.c_uint32),
|
||||
("n_threads", ctypes.c_uint32),
|
||||
("n_threads_batch", ctypes.c_uint32),
|
||||
("rope_scaling_type", ctypes.c_int32),
|
||||
("rope_scaling_type", ctypes.c_int),
|
||||
("pooling_type", ctypes.c_int),
|
||||
("rope_freq_base", ctypes.c_float),
|
||||
("rope_freq_scale", ctypes.c_float),
|
||||
("yarn_ext_factor", ctypes.c_float),
|
||||
|
@ -615,11 +636,11 @@ class llama_context_params(ctypes.Structure):
|
|||
("cb_eval_user_data", ctypes.c_void_p),
|
||||
("type_k", ctypes.c_int),
|
||||
("type_v", ctypes.c_int),
|
||||
("mul_mat_q", ctypes.c_bool),
|
||||
("logits_all", ctypes.c_bool),
|
||||
("embedding", ctypes.c_bool),
|
||||
("embeddings", ctypes.c_bool),
|
||||
("offload_kqv", ctypes.c_bool),
|
||||
("do_pooling", ctypes.c_bool),
|
||||
("abort_callback", ggml_abort_callback),
|
||||
("abort_callback_data", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
|
||||
|
@ -1306,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
|
|||
# // seq_id < 0 : match any sequence
|
||||
# // p0 < 0 : [0, p1]
|
||||
# // p1 < 0 : [p0, inf)
|
||||
# LLAMA_API void llama_kv_cache_seq_rm(
|
||||
# LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
# struct llama_context * ctx,
|
||||
# llama_seq_id seq_id,
|
||||
# llama_pos p0,
|
||||
|
@ -1319,7 +1340,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
|
|||
llama_pos,
|
||||
llama_pos,
|
||||
],
|
||||
None,
|
||||
ctypes.c_bool,
|
||||
)
|
||||
def llama_kv_cache_seq_rm(
|
||||
ctx: llama_context_p,
|
||||
|
@ -1327,7 +1348,7 @@ def llama_kv_cache_seq_rm(
|
|||
p0: Union[llama_pos, int],
|
||||
p1: Union[llama_pos, int],
|
||||
/,
|
||||
):
|
||||
) -> bool:
|
||||
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
seq_id < 0 : match any sequence
|
||||
p0 < 0 : [0, p1]
|
||||
|
@ -1519,11 +1540,11 @@ def llama_copy_state_data(
|
|||
...
|
||||
|
||||
|
||||
# Set the state reading from the specified address
|
||||
# Returns the number of bytes read
|
||||
# // Set the state reading from the specified address
|
||||
# // Returns the number of bytes read
|
||||
# LLAMA_API size_t llama_set_state_data(
|
||||
# struct llama_context * ctx,
|
||||
# uint8_t * src);
|
||||
# const uint8_t * src);
|
||||
@ctypes_function(
|
||||
"llama_set_state_data",
|
||||
[llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
|
||||
|
@ -1707,8 +1728,24 @@ def llama_set_n_threads(
|
|||
"""
|
||||
...
|
||||
|
||||
# // Set abort callback
|
||||
# LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||
@ctypes_function(
|
||||
"llama_set_abort_callback",
|
||||
[llama_context_p_ctypes, ggml_abort_callback, ctypes.c_void_p],
|
||||
None,
|
||||
)
|
||||
def llama_set_abort_callback(
|
||||
ctx: llama_context_p,
|
||||
abort_callback: Callable[[ctypes.c_void_p], None],
|
||||
abort_callback_data: ctypes.c_void_p,
|
||||
/,
|
||||
):
|
||||
"""Set abort callback"""
|
||||
...
|
||||
|
||||
# // Token logits obtained from the last call to llama_eval()
|
||||
|
||||
# // Token logits obtained from the last call to llama_decode()
|
||||
# // The logits for the last token are stored in the last row
|
||||
# // Logits for which llama_batch.logits[i] == 0 are undefined
|
||||
# // Rows: n_tokens provided with llama_batch
|
||||
|
@ -1722,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
|
|||
The logits for the last token are stored in the last row
|
||||
Logits for which llama_batch.logits[i] == 0 are undefined
|
||||
Rows: n_tokens provided with llama_batch
|
||||
Cols: n_vocab"""
|
||||
Cols: n_vocab
|
||||
|
||||
Returns:
|
||||
Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
|
||||
...
|
||||
|
||||
|
||||
|
@ -1742,8 +1782,8 @@ def llama_get_logits_ith(
|
|||
...
|
||||
|
||||
|
||||
# Get the embeddings for the input
|
||||
# shape: [n_embd] (1-dimensional)
|
||||
# // Get all output token embeddings
|
||||
# // shape: [n_tokens*n_embd] (1-dimensional)
|
||||
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
@ctypes_function(
|
||||
"llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
|
||||
|
@ -1754,8 +1794,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
|
|||
...
|
||||
|
||||
|
||||
# // Get the embeddings for the ith sequence
|
||||
# // Get the embeddings for the ith token
|
||||
# // llama_get_embeddings(ctx) + i*n_embd
|
||||
# // shape: [n_embd] (1-dimensional)
|
||||
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
@ctypes_function(
|
||||
"llama_get_embeddings_ith",
|
||||
|
@ -1770,6 +1811,23 @@ def llama_get_embeddings_ith(
|
|||
...
|
||||
|
||||
|
||||
# // Get the embeddings for a sequence id
|
||||
# // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||
# // shape: [n_embd] (1-dimensional)
|
||||
# LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
@ctypes_function(
|
||||
"llama_get_embeddings_seq",
|
||||
[llama_context_p_ctypes, llama_seq_id],
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
)
|
||||
def llama_get_embeddings_seq(
|
||||
ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /
|
||||
) -> CtypesArray[ctypes.c_float]:
|
||||
"""Get the embeddings for a sequence id
|
||||
Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||
shape: [n_embd] (1-dimensional)"""
|
||||
...
|
||||
|
||||
# //
|
||||
# // Vocab
|
||||
# //
|
||||
|
|
|
@ -1337,7 +1337,7 @@ array ::=
|
|||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
|
@ -1366,7 +1366,7 @@ array ::=
|
|||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
|
|
|
@ -41,6 +41,11 @@ from llama_cpp.server.types import (
|
|||
CreateEmbeddingRequest,
|
||||
CreateChatCompletionRequest,
|
||||
ModelList,
|
||||
TokenizeInputRequest,
|
||||
TokenizeInputResponse,
|
||||
TokenizeInputCountResponse,
|
||||
DetokenizeInputRequest,
|
||||
DetokenizeInputResponse,
|
||||
)
|
||||
from llama_cpp.server.errors import RouteErrorHandler
|
||||
|
||||
|
@ -196,6 +201,9 @@ async def authenticate(
|
|||
)
|
||||
|
||||
|
||||
openai_v1_tag = "OpenAI V1"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
summary="Completion",
|
||||
|
@ -227,11 +235,13 @@ async def authenticate(
|
|||
},
|
||||
}
|
||||
},
|
||||
tags=[openai_v1_tag],
|
||||
)
|
||||
@router.post(
|
||||
"/v1/engines/copilot-codex/completions",
|
||||
include_in_schema=False,
|
||||
dependencies=[Depends(authenticate)],
|
||||
tags=[openai_v1_tag],
|
||||
)
|
||||
async def create_completion(
|
||||
request: Request,
|
||||
|
@ -297,7 +307,10 @@ async def create_completion(
|
|||
|
||||
|
||||
@router.post(
|
||||
"/v1/embeddings", summary="Embedding", dependencies=[Depends(authenticate)]
|
||||
"/v1/embeddings",
|
||||
summary="Embedding",
|
||||
dependencies=[Depends(authenticate)],
|
||||
tags=[openai_v1_tag],
|
||||
)
|
||||
async def create_embedding(
|
||||
request: CreateEmbeddingRequest,
|
||||
|
@ -339,6 +352,7 @@ async def create_embedding(
|
|||
},
|
||||
}
|
||||
},
|
||||
tags=[openai_v1_tag],
|
||||
)
|
||||
async def create_chat_completion(
|
||||
request: Request,
|
||||
|
@ -391,7 +405,12 @@ async def create_chat_completion(
|
|||
return iterator_or_completion
|
||||
|
||||
|
||||
@router.get("/v1/models", summary="Models", dependencies=[Depends(authenticate)])
|
||||
@router.get(
|
||||
"/v1/models",
|
||||
summary="Models",
|
||||
dependencies=[Depends(authenticate)],
|
||||
tags=[openai_v1_tag],
|
||||
)
|
||||
async def get_models(
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> ModelList:
|
||||
|
@ -407,3 +426,51 @@ async def get_models(
|
|||
for model_alias in llama_proxy
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
extras_tag = "Extras"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extras/tokenize",
|
||||
summary="Tokenize",
|
||||
dependencies=[Depends(authenticate)],
|
||||
tags=[extras_tag],
|
||||
)
|
||||
async def tokenize(
|
||||
body: TokenizeInputRequest,
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> TokenizeInputResponse:
|
||||
tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
|
||||
|
||||
return {"tokens": tokens}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extras/tokenize/count",
|
||||
summary="Tokenize Count",
|
||||
dependencies=[Depends(authenticate)],
|
||||
tags=[extras_tag],
|
||||
)
|
||||
async def count_query_tokens(
|
||||
body: TokenizeInputRequest,
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> TokenizeInputCountResponse:
|
||||
tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
|
||||
|
||||
return {"count": len(tokens)}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extras/detokenize",
|
||||
summary="Detokenize",
|
||||
dependencies=[Depends(authenticate)],
|
||||
tags=[extras_tag],
|
||||
)
|
||||
async def detokenize(
|
||||
body: DetokenizeInputRequest,
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
) -> DetokenizeInputResponse:
|
||||
text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8")
|
||||
|
||||
return {"text": text}
|
||||
|
|
|
@ -264,3 +264,39 @@ class ModelData(TypedDict):
|
|||
class ModelList(TypedDict):
|
||||
object: Literal["list"]
|
||||
data: List[ModelData]
|
||||
|
||||
|
||||
class TokenizeInputRequest(BaseModel):
|
||||
model: Optional[str] = model_field
|
||||
input: Optional[str] = Field(description="The input to tokenize.")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {"examples": [{"input": "How many tokens in this query?"}]}
|
||||
}
|
||||
|
||||
|
||||
class TokenizeInputResponse(BaseModel):
|
||||
tokens: List[int] = Field(description="A list of tokens.")
|
||||
|
||||
model_config = {"json_schema_extra": {"example": {"tokens": [123, 321, 222]}}}
|
||||
|
||||
|
||||
class TokenizeInputCountResponse(BaseModel):
|
||||
count: int = Field(description="The number of tokens in the input.")
|
||||
|
||||
model_config = {"json_schema_extra": {"example": {"count": 5}}}
|
||||
|
||||
|
||||
class DetokenizeInputRequest(BaseModel):
|
||||
model: Optional[str] = model_field
|
||||
tokens: List[int] = Field(description="A list of toekns to detokenize.")
|
||||
|
||||
model_config = {"json_schema_extra": {"example": [{"tokens": [123, 321, 222]}]}}
|
||||
|
||||
|
||||
class DetokenizeInputResponse(BaseModel):
|
||||
text: str = Field(description="The detokenized text.")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {"example": {"text": "How many tokens in this query?"}}
|
||||
}
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 08c5ee87e4cceb603ecceac90734fcdade57311b
|
||||
Subproject commit c2101a2e909ac7c08976d414e64e96c90ee5fa9e
|
Loading…
Reference in a new issue