This commit is contained in:
Mug 2023-04-26 14:38:09 +02:00
commit be2c961bc9
8 changed files with 203 additions and 113 deletions

View file

@ -2,7 +2,11 @@ cmake_minimum_required(VERSION 3.4...3.22)
project(llama_cpp) project(llama_cpp)
if (UNIX) option(FORCE_CMAKE "Force CMake build of Python bindings" OFF)
set(FORCE_CMAKE $ENV{FORCE_CMAKE})
if (UNIX AND NOT FORCE_CMAKE)
add_custom_command( add_custom_command(
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/libllama.so OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/libllama.so
COMMAND make libllama.so COMMAND make libllama.so

View file

@ -105,12 +105,16 @@ python3 setup.py develop
- __call__ - __call__
- create_chat_completion - create_chat_completion
- set_cache - set_cache
- save_state
- load_state
- token_bos - token_bos
- token_eos - token_eos
show_root_heading: true show_root_heading: true
::: llama_cpp.LlamaCache ::: llama_cpp.LlamaCache
::: llama_cpp.LlamaState
::: llama_cpp.llama_cpp ::: llama_cpp.llama_cpp
options: options:
show_if_no_docstring: true show_if_no_docstring: true

View file

@ -4,7 +4,7 @@ import uuid
import time import time
import math import math
import multiprocessing import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
from collections import deque from collections import deque
from . import llama_cpp from . import llama_cpp
@ -12,12 +12,53 @@ from .llama_types import *
class LlamaCache: class LlamaCache:
"""Cache for a llama.cpp model. """Cache for a llama.cpp model."""
NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last def __init__(self):
completion. It does not actually cache the results.""" self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
pass def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
return [
key
for _, key in sorted(
((len(key), key) for key in self.cache_state.keys()), reverse=True
)
]
def _find_key(
self, key: Tuple[llama_cpp.llama_token, ...]
) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
for k in self._sorted_keys():
if key[: len(k)] == k:
return k
return None
def __getitem__(
self, key: Sequence[llama_cpp.llama_token]
) -> Optional["LlamaState"]:
_key = self._find_key(tuple(key))
if _key is None:
return None
return self.cache_state[_key]
def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
return self._find_key(tuple(key)) is not None
def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
self.cache_state = dict() # NOTE: Currently limit to one cache entry.
self.cache_state[tuple(key)] = value
class LlamaState:
def __init__(
self,
eval_tokens: Deque[llama_cpp.llama_token],
eval_logits: Deque[List[float]],
llama_state,
):
self.eval_tokens = eval_tokens
self.eval_logits = eval_logits
self.llama_state = llama_state
class Llama: class Llama:
@ -37,8 +78,10 @@ class Llama:
use_mlock: bool = False, use_mlock: bool = False,
embedding: bool = False, embedding: bool = False,
n_threads: Optional[int] = None, n_threads: Optional[int] = None,
n_batch: int = 8, n_batch: int = 512,
last_n_tokens_size: int = 64, last_n_tokens_size: int = 64,
lora_base: Optional[str] = None,
lora_path: Optional[str] = None,
verbose: bool = True, verbose: bool = True,
): ):
"""Load a llama.cpp model from `model_path`. """Load a llama.cpp model from `model_path`.
@ -57,6 +100,8 @@ class Llama:
n_threads: Number of threads to use. If None, the number of threads is automatically determined. n_threads: Number of threads to use. If None, the number of threads is automatically determined.
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval. n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model.
verbose: Print verbose output to stderr. verbose: Print verbose output to stderr.
Raises: Raises:
@ -75,32 +120,22 @@ class Llama:
self.params.f16_kv = f16_kv self.params.f16_kv = f16_kv
self.params.logits_all = logits_all self.params.logits_all = logits_all
self.params.vocab_only = vocab_only self.params.vocab_only = vocab_only
self.params.use_mmap = use_mmap self.params.use_mmap = use_mmap if lora_path is None else False
self.params.use_mlock = use_mlock self.params.use_mlock = use_mlock
self.params.embedding = embedding self.params.embedding = embedding
self.last_n_tokens_size = last_n_tokens_size self.last_n_tokens_size = last_n_tokens_size
self.last_n_tokens_data = deque(
[llama_cpp.llama_token(0)] * self.last_n_tokens_size,
maxlen=self.last_n_tokens_size,
)
self.tokens_consumed = 0
self.tokens: List[llama_cpp.llama_token] = []
self.n_batch = min(n_ctx, n_batch) self.n_batch = min(n_ctx, n_batch)
self.n_tokens = 0 self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
self.n_past = 0 self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support self.cache: Optional[LlamaCache] = None
### saving and restoring state, this allows us to continue a completion if the last
### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
### because it does not take into account stop tokens which have been processed by the model.
self._completion_bytes: List[bytes] = []
self._cache: Optional[LlamaCache] = None
###
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
self.lora_base = lora_base
self.lora_path = lora_path
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}") raise ValueError(f"Model path does not exist: {model_path}")
@ -108,6 +143,21 @@ class Llama:
self.model_path.encode("utf-8"), self.params self.model_path.encode("utf-8"), self.params
) )
assert self.ctx is not None
if self.lora_path:
if llama_cpp.llama_apply_lora_from_file(
self.ctx,
llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
if self.lora_base is not None
else llama_cpp.c_char_p(0),
llama_cpp.c_int(self.n_threads),
):
raise RuntimeError(
f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
)
if self.verbose: if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr) print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr)
@ -158,18 +208,12 @@ class Llama:
Args: Args:
cache: The cache to set. cache: The cache to set.
""" """
self._cache = cache self.cache = cache
def reset(self): def reset(self):
"""Reset the model state.""" """Reset the model state."""
self.last_n_tokens_data.extend( self.eval_tokens.clear()
[llama_cpp.llama_token(0)] * self.last_n_tokens_size self.eval_logits.clear()
)
self.tokens_consumed = 0
self.tokens.clear()
self.n_tokens = 0
self.n_past = 0
self.all_logits.clear()
def eval(self, tokens: Sequence[llama_cpp.llama_token]): def eval(self, tokens: Sequence[llama_cpp.llama_token]):
"""Evaluate a list of tokens. """Evaluate a list of tokens.
@ -181,32 +225,28 @@ class Llama:
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
for i in range(0, len(tokens), self.n_batch): for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)] batch = tokens[i : min(len(tokens), i + self.n_batch)]
self.n_past = min(n_ctx - len(batch), self.tokens_consumed) n_past = min(n_ctx - len(batch), len(self.eval_tokens))
self.n_tokens = len(batch) n_tokens = len(batch)
return_code = llama_cpp.llama_eval( return_code = llama_cpp.llama_eval(
ctx=self.ctx, ctx=self.ctx,
tokens=(llama_cpp.llama_token * len(batch))(*batch), tokens=(llama_cpp.llama_token * len(batch))(*batch),
n_tokens=llama_cpp.c_int(self.n_tokens), n_tokens=llama_cpp.c_int(n_tokens),
n_past=llama_cpp.c_int(self.n_past), n_past=llama_cpp.c_int(n_past),
n_threads=llama_cpp.c_int(self.n_threads), n_threads=llama_cpp.c_int(self.n_threads),
) )
if int(return_code) != 0: if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}") raise RuntimeError(f"llama_eval returned {return_code}")
self.tokens.extend(batch) self.eval_tokens.extend(batch)
self.last_n_tokens_data.extend(batch)
self.tokens_consumed += len(batch)
if self.params.logits_all: if self.params.logits_all:
self.all_logits.extend(self._logits())
def _logits(self) -> List[List[float]]:
"""Return the logits from the last call to llama_eval."""
assert self.ctx is not None
n_vocab = llama_cpp.llama_n_vocab(self.ctx) n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab) cols = int(n_vocab)
rows = self.n_tokens if self.params.logits_all else 1 rows = n_tokens
logits_view = llama_cpp.llama_get_logits(self.ctx) logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)] logits = [
return logits [logits_view[i * cols + j] for j in range(cols)]
for i in range(rows)
]
self.eval_logits.extend(logits)
def sample( def sample(
self, self,
@ -227,10 +267,13 @@ class Llama:
The sampled token. The sampled token.
""" """
assert self.ctx is not None assert self.ctx is not None
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
return llama_cpp.llama_sample_top_p_top_k( return llama_cpp.llama_sample_top_p_top_k(
ctx=self.ctx, ctx=self.ctx,
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*self.last_n_tokens_data *last_n_tokens_data
), ),
last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size), last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
top_k=llama_cpp.c_int(top_k), top_k=llama_cpp.c_int(top_k),
@ -270,18 +313,17 @@ class Llama:
The generated tokens. The generated tokens.
""" """
assert self.ctx is not None assert self.ctx is not None
### HACK
if ( if (
reset reset
and self._cache and len(self.eval_tokens) > 0
and len(self.tokens) > 0 and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
and self.tokens == tokens[: len(self.tokens)]
): ):
if self.verbose: if self.verbose:
print("generate cache hit", file=sys.stderr) print("generate cache hit", file=sys.stderr)
reset = False reset = False
tokens = tokens[len(self.tokens) :] tokens = tokens[len(self.eval_tokens) :]
###
if reset: if reset:
self.reset() self.reset()
while True: while True:
@ -398,20 +440,10 @@ class Llama:
"logprobs is not supported for models created with logits_all=False" "logprobs is not supported for models created with logits_all=False"
) )
### HACK if self.cache and prompt_tokens in self.cache:
reset: bool = True
_prompt: bytes = prompt.encode("utf-8")
_completion: bytes = b"".join(self._completion_bytes)
if len(_completion) and self._cache and _prompt.startswith(_completion):
if self.verbose: if self.verbose:
print("completion cache hit", file=sys.stderr) print("cache hit", file=sys.stderr)
reset = False self.load_state(self.cache[prompt_tokens])
_prompt = _prompt[len(_completion) :]
prompt_tokens = self.tokenize(b" " + _prompt)
self._completion_bytes.append(_prompt)
else:
self._completion_bytes = [prompt.encode("utf-8")]
###
finish_reason = "length" finish_reason = "length"
for token in self.generate( for token in self.generate(
@ -420,12 +452,18 @@ class Llama:
top_p=top_p, top_p=top_p,
temp=temperature, temp=temperature,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
reset=reset,
): ):
if token == llama_cpp.llama_token_eos(): if token == llama_cpp.llama_token_eos():
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens)
finish_reason = "stop" finish_reason = "stop"
break break
if self.cache and len(completion_tokens) == 0:
if prompt_tokens not in self.cache:
if self.verbose:
print("cache miss", file=sys.stderr)
self.cache[prompt_tokens] = self.save_state()
completion_tokens.append(token) completion_tokens.append(token)
all_text = self.detokenize(completion_tokens) all_text = self.detokenize(completion_tokens)
@ -450,9 +488,6 @@ class Llama:
break break
text = all_text[: len(all_text) - longest] text = all_text[: len(all_text) - longest]
returned_characters += len(text[start:]) returned_characters += len(text[start:])
### HACK
self._completion_bytes.append(text[start:])
###
yield { yield {
"id": completion_id, "id": completion_id,
"object": "text_completion", "object": "text_completion",
@ -474,9 +509,6 @@ class Llama:
break break
if stream: if stream:
### HACK
self._completion_bytes.append(text[returned_characters:])
###
yield { yield {
"id": completion_id, "id": completion_id,
"object": "text_completion", "object": "text_completion",
@ -493,9 +525,6 @@ class Llama:
} }
return return
### HACK
self._completion_bytes.append(text)
###
text_str = text.decode("utf-8", errors="ignore") text_str = text.decode("utf-8", errors="ignore")
if echo: if echo:
@ -518,7 +547,7 @@ class Llama:
] ]
all_logprobs = [ all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row] [Llama.logit_to_logprob(logit) for logit in row]
for row in self.all_logits for row in self.eval_logits
] ]
for token, token_str, logprobs_token in zip( for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs all_tokens, all_token_strs, all_logprobs
@ -802,6 +831,8 @@ class Llama:
last_n_tokens_size=self.last_n_tokens_size, last_n_tokens_size=self.last_n_tokens_size,
n_batch=self.n_batch, n_batch=self.n_batch,
n_threads=self.n_threads, n_threads=self.n_threads,
lora_base=self.lora_base,
lora_path=self.lora_path,
) )
def __setstate__(self, state): def __setstate__(self, state):
@ -819,9 +850,31 @@ class Llama:
n_threads=state["n_threads"], n_threads=state["n_threads"],
n_batch=state["n_batch"], n_batch=state["n_batch"],
last_n_tokens_size=state["last_n_tokens_size"], last_n_tokens_size=state["last_n_tokens_size"],
lora_base=state["lora_base"],
lora_path=state["lora_path"],
verbose=state["verbose"], verbose=state["verbose"],
) )
def save_state(self) -> LlamaState:
assert self.ctx is not None
state_size = llama_cpp.llama_get_state_size(self.ctx)
llama_state = (llama_cpp.c_uint8 * int(state_size))()
if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
raise RuntimeError("Failed to copy llama state data")
return LlamaState(
eval_tokens=self.eval_tokens.copy(),
eval_logits=self.eval_logits.copy(),
llama_state=llama_state,
)
def load_state(self, state: LlamaState) -> None:
assert self.ctx is not None
self.eval_tokens = state.eval_tokens.copy()
self.eval_logits = state.eval_logits.copy()
state_size = llama_cpp.llama_get_state_size(self.ctx)
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
raise RuntimeError("Failed to set llama state data")
@staticmethod @staticmethod
def token_eos() -> llama_cpp.llama_token: def token_eos() -> llama_cpp.llama_token:
"""Return the end-of-sequence token.""" """Return the end-of-sequence token."""

View file

@ -114,7 +114,12 @@ LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(4) # tok_embeddings.weight and output.weight are F16 LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(
4
) # tok_embeddings.weight and output.weight are F16
LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes.c_int(5) # except 1d tensors
LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
LLAMA_FTYPE_MOSTYL_Q8_0 = ctypes.c_int(7) # except 1d tensors
# Functions # Functions
@ -167,31 +172,34 @@ _lib.llama_free.restype = None
# TODO: not great API - very likely to change # TODO: not great API - very likely to change
# Returns 0 on success # Returns 0 on success
def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int: # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
return _lib.llama_model_quantize(fname_inp, fname_out, itype) def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
) -> c_int:
return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int] _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
_lib.llama_model_quantize.restype = c_int _lib.llama_model_quantize.restype = c_int
# Returns the KV cache that will contain the context for the # Apply a LoRA adapter to a loaded model
# ongoing prediction with the model. # path_base_model is the path to a higher quality model to use as a base for
def llama_get_kv_cache(ctx: llama_context_p): # the layers modified by the adapter. Can be NULL to use the current loaded model.
return _lib.llama_get_kv_cache(ctx) # The model needs to be reloaded before applying a new adapter, otherwise the adapter
# will be applied on top of the previous one
# Returns 0 on success
def llama_apply_lora_from_file(
ctx: llama_context_p,
path_lora: ctypes.c_char_p,
path_base_model: ctypes.c_char_p,
n_threads: c_int,
) -> c_int:
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
_lib.llama_get_kv_cache.argtypes = [llama_context_p] _lib.llama_apply_lora_from_file.argtypes = [llama_context_p, c_char_p, c_char_p, c_int]
_lib.llama_get_kv_cache.restype = POINTER(c_uint8) _lib.llama_apply_lora_from_file.restype = c_int
# Returns the size of the KV cache
def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
return _lib.llama_get_kv_cache_size(ctx)
_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_size.restype = c_size_t
# Returns the number of tokens in the KV cache # Returns the number of tokens in the KV cache
@ -203,15 +211,34 @@ _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int _lib.llama_get_kv_cache_token_count.restype = c_int
# Sets the KV cache containing the current context for the model # Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
def llama_set_kv_cache( def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int return _lib.llama_get_state_size(ctx)
):
return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int] _lib.llama_get_state_size.argtypes = [llama_context_p]
_lib.llama_set_kv_cache.restype = None _lib.llama_get_state_size.restype = c_size_t
# Copies the state to the specified destination address.
# Destination needs to have allocated enough memory.
# Returns the number of bytes copied
def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
return _lib.llama_copy_state_data(ctx, dest)
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
_lib.llama_copy_state_data.restype = c_size_t
# Set the state reading from the specified address
# Returns the number of bytes read
def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
return _lib.llama_set_state_data(ctx, src)
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
_lib.llama_set_state_data.restype = c_size_t
# Run the llama inference to obtain the logits and probabilities for the next token. # Run the llama inference to obtain the logits and probabilities for the next token.

View file

@ -28,10 +28,11 @@ from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings): class Settings(BaseSettings):
model: str model: str
n_ctx: int = 2048 n_ctx: int = 2048
n_batch: int = 8 n_batch: int = 512
n_threads: int = ((os.cpu_count() or 2) // 2) or 1 n_threads: int = max((os.cpu_count() or 2) // 2, 1)
f16_kv: bool = True f16_kv: bool = True
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out... use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
use_mmap: bool = True
embedding: bool = True embedding: bool = True
last_n_tokens_size: int = 64 last_n_tokens_size: int = 64
logits_all: bool = False logits_all: bool = False
@ -54,6 +55,7 @@ llama = llama_cpp.Llama(
settings.model, settings.model,
f16_kv=settings.f16_kv, f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock, use_mlock=settings.use_mlock,
use_mmap=settings.use_mmap,
embedding=settings.embedding, embedding=settings.embedding,
logits_all=settings.logits_all, logits_all=settings.logits_all,
n_threads=settings.n_threads, n_threads=settings.n_threads,

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llama_cpp_python" name = "llama_cpp_python"
version = "0.1.34" version = "0.1.38"
description = "Python bindings for the llama.cpp library" description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"] authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT" license = "MIT"

View file

@ -10,7 +10,7 @@ setup(
description="A Python wrapper for llama.cpp", description="A Python wrapper for llama.cpp",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
version="0.1.34", version="0.1.38",
author="Andrei Betlen", author="Andrei Betlen",
author_email="abetlen@gmail.com", author_email="abetlen@gmail.com",
license="MIT", license="MIT",

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit e95b6554b493e71a0275764342e09bd5784a7026 Subproject commit 4afcc378698e057fcde64e23eb664e5af8dd6956