feat: Pull models directly from huggingface (#1206)
* Add from_pretrained method to Llama class * Update docs * Merge filename and pattern
This commit is contained in:
parent
e42f62c247
commit
0f8aa4ab5c
3 changed files with 136 additions and 27 deletions
13
README.md
13
README.md
|
@ -212,6 +212,19 @@ Below is a short example demonstrating how to use the high-level API to for basi
|
||||||
|
|
||||||
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
|
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
|
||||||
|
|
||||||
|
## Pulling models from Hugging Face
|
||||||
|
|
||||||
|
You can pull `Llama` models from Hugging Face using the `from_pretrained` method.
|
||||||
|
You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
|
||||||
|
|
||||||
|
```python
|
||||||
|
llama = Llama.from_pretrained(
|
||||||
|
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
|
||||||
|
filename="*q8_0.gguf",
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### Chat Completion
|
### Chat Completion
|
||||||
|
|
||||||
The high-level API also provides a simple interface for chat completion.
|
The high-level API also provides a simple interface for chat completion.
|
||||||
|
|
|
@ -26,6 +26,7 @@ High-level Python bindings for llama.cpp.
|
||||||
- load_state
|
- load_state
|
||||||
- token_bos
|
- token_bos
|
||||||
- token_eos
|
- token_eos
|
||||||
|
- from_pretrained
|
||||||
show_root_heading: true
|
show_root_heading: true
|
||||||
|
|
||||||
::: llama_cpp.LlamaGrammar
|
::: llama_cpp.LlamaGrammar
|
||||||
|
|
|
@ -4,6 +4,8 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
import fnmatch
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import (
|
from typing import (
|
||||||
List,
|
List,
|
||||||
|
@ -16,6 +18,7 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
)
|
)
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
|
@ -29,10 +32,7 @@ from .llama_cache import (
|
||||||
LlamaDiskCache, # type: ignore
|
LlamaDiskCache, # type: ignore
|
||||||
LlamaRAMCache, # type: ignore
|
LlamaRAMCache, # type: ignore
|
||||||
)
|
)
|
||||||
from .llama_tokenizer import (
|
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
|
||||||
BaseLlamaTokenizer,
|
|
||||||
LlamaTokenizer
|
|
||||||
)
|
|
||||||
import llama_cpp.llama_cpp as llama_cpp
|
import llama_cpp.llama_cpp as llama_cpp
|
||||||
import llama_cpp.llama_chat_format as llama_chat_format
|
import llama_cpp.llama_chat_format as llama_chat_format
|
||||||
|
|
||||||
|
@ -50,9 +50,7 @@ from ._internals import (
|
||||||
_LlamaSamplingContext, # type: ignore
|
_LlamaSamplingContext, # type: ignore
|
||||||
)
|
)
|
||||||
from ._logger import set_verbose
|
from ._logger import set_verbose
|
||||||
from ._utils import (
|
from ._utils import suppress_stdout_stderr
|
||||||
suppress_stdout_stderr
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Llama:
|
class Llama:
|
||||||
|
@ -189,7 +187,11 @@ class Llama:
|
||||||
Llama.__backend_initialized = True
|
Llama.__backend_initialized = True
|
||||||
|
|
||||||
if isinstance(numa, bool):
|
if isinstance(numa, bool):
|
||||||
self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
|
self.numa = (
|
||||||
|
llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
|
||||||
|
if numa
|
||||||
|
else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.numa = numa
|
self.numa = numa
|
||||||
|
|
||||||
|
@ -246,9 +248,9 @@ class Llama:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||||
|
|
||||||
self._kv_overrides_array[
|
self._kv_overrides_array[-1].key = (
|
||||||
-1
|
b"\0" # ensure sentinel element is zeroed
|
||||||
].key = b"\0" # ensure sentinel element is zeroed
|
)
|
||||||
self.model_params.kv_overrides = self._kv_overrides_array
|
self.model_params.kv_overrides = self._kv_overrides_array
|
||||||
|
|
||||||
self.n_batch = min(n_ctx, n_batch) # ???
|
self.n_batch = min(n_ctx, n_batch) # ???
|
||||||
|
@ -289,7 +291,9 @@ class Llama:
|
||||||
)
|
)
|
||||||
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
|
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
|
||||||
self.context_params.mul_mat_q = mul_mat_q
|
self.context_params.mul_mat_q = mul_mat_q
|
||||||
self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
|
self.context_params.logits_all = (
|
||||||
|
logits_all if draft_model is None else True
|
||||||
|
) # Must be set to True for speculative decoding
|
||||||
self.context_params.embedding = embedding
|
self.context_params.embedding = embedding
|
||||||
self.context_params.offload_kqv = offload_kqv
|
self.context_params.offload_kqv = offload_kqv
|
||||||
|
|
||||||
|
@ -379,8 +383,14 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
||||||
|
|
||||||
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
|
if (
|
||||||
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
|
self.chat_format is None
|
||||||
|
and self.chat_handler is None
|
||||||
|
and "tokenizer.chat_template" in self.metadata
|
||||||
|
):
|
||||||
|
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
|
||||||
|
self.metadata
|
||||||
|
)
|
||||||
|
|
||||||
if chat_format is not None:
|
if chat_format is not None:
|
||||||
self.chat_format = chat_format
|
self.chat_format = chat_format
|
||||||
|
@ -406,9 +416,7 @@ class Llama:
|
||||||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||||
|
|
||||||
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
||||||
template=template,
|
template=template, eos_token=eos_token, bos_token=bos_token
|
||||||
eos_token=eos_token,
|
|
||||||
bos_token=bos_token
|
|
||||||
).to_chat_handler()
|
).to_chat_handler()
|
||||||
|
|
||||||
if self.chat_format is None and self.chat_handler is None:
|
if self.chat_format is None and self.chat_handler is None:
|
||||||
|
@ -459,7 +467,9 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
return self.tokenizer_.tokenize(text, add_bos, special)
|
return self.tokenizer_.tokenize(text, add_bos, special)
|
||||||
|
|
||||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
def detokenize(
|
||||||
|
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||||
|
) -> bytes:
|
||||||
"""Detokenize a list of tokens.
|
"""Detokenize a list of tokens.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -565,7 +575,7 @@ class Llama:
|
||||||
logits[:] = (
|
logits[:] = (
|
||||||
logits_processor(self._input_ids, logits)
|
logits_processor(self._input_ids, logits)
|
||||||
if idx is None
|
if idx is None
|
||||||
else logits_processor(self._input_ids[:idx + 1], logits)
|
else logits_processor(self._input_ids[: idx + 1], logits)
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = _LlamaSamplingParams(
|
sampling_params = _LlamaSamplingParams(
|
||||||
|
@ -707,7 +717,9 @@ class Llama:
|
||||||
|
|
||||||
if self.draft_model is not None:
|
if self.draft_model is not None:
|
||||||
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
|
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
|
||||||
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
|
draft_tokens = self.draft_model(
|
||||||
|
self.input_ids[: self.n_tokens + len(tokens)]
|
||||||
|
)
|
||||||
tokens.extend(
|
tokens.extend(
|
||||||
draft_tokens.astype(int)[
|
draft_tokens.astype(int)[
|
||||||
: self._n_ctx - self.n_tokens - len(tokens)
|
: self._n_ctx - self.n_tokens - len(tokens)
|
||||||
|
@ -792,6 +804,7 @@ class Llama:
|
||||||
|
|
||||||
# decode and fetch embeddings
|
# decode and fetch embeddings
|
||||||
data: List[List[float]] = []
|
data: List[List[float]] = []
|
||||||
|
|
||||||
def decode_batch(n_seq: int):
|
def decode_batch(n_seq: int):
|
||||||
assert self._ctx.ctx is not None
|
assert self._ctx.ctx is not None
|
||||||
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
||||||
|
@ -800,9 +813,9 @@ class Llama:
|
||||||
|
|
||||||
# store embeddings
|
# store embeddings
|
||||||
for i in range(n_seq):
|
for i in range(n_seq):
|
||||||
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
|
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
|
||||||
:n_embd
|
self._ctx.ctx, i
|
||||||
]
|
)[:n_embd]
|
||||||
if normalize:
|
if normalize:
|
||||||
norm = float(np.linalg.norm(embedding))
|
norm = float(np.linalg.norm(embedding))
|
||||||
embedding = [v / norm for v in embedding]
|
embedding = [v / norm for v in embedding]
|
||||||
|
@ -1669,12 +1682,13 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
stream = kwargs.get("stream", False) # type: ignore
|
|
||||||
|
stream = kwargs.get("stream", False) # type: ignore
|
||||||
assert isinstance(stream, bool)
|
assert isinstance(stream, bool)
|
||||||
if stream:
|
if stream:
|
||||||
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
|
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
|
||||||
else:
|
else:
|
||||||
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
|
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"To use create_chat_completion_openai_v1, you must install the openai package."
|
"To use create_chat_completion_openai_v1, you must install the openai package."
|
||||||
|
@ -1866,7 +1880,88 @@ class Llama:
|
||||||
break
|
break
|
||||||
return longest_prefix
|
return longest_prefix
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
repo_id: str,
|
||||||
|
filename: Optional[str],
|
||||||
|
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
|
||||||
|
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "Llama":
|
||||||
|
"""Create a Llama model from a pretrained model name or path.
|
||||||
|
This method requires the huggingface-hub package.
|
||||||
|
You can install it with `pip install huggingface-hub`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id: The model repo id.
|
||||||
|
filename: A filename or glob pattern to match the model file in the repo.
|
||||||
|
local_dir: The local directory to save the model to.
|
||||||
|
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
|
||||||
|
**kwargs: Additional keyword arguments to pass to the Llama constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Llama model."""
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download, HfFileSystem
|
||||||
|
from huggingface_hub.utils import validate_repo_id
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Llama.from_pretrained requires the huggingface-hub package. "
|
||||||
|
"You can install it with `pip install huggingface-hub`."
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_repo_id(repo_id)
|
||||||
|
|
||||||
|
hffs = HfFileSystem()
|
||||||
|
|
||||||
|
files = [
|
||||||
|
file["name"] if isinstance(file, dict) else file
|
||||||
|
for file in hffs.ls(repo_id)
|
||||||
|
]
|
||||||
|
|
||||||
|
# split each file into repo_id, subfolder, filename
|
||||||
|
file_list: List[str] = []
|
||||||
|
for file in files:
|
||||||
|
rel_path = Path(file).relative_to(repo_id)
|
||||||
|
file_list.append(str(rel_path))
|
||||||
|
|
||||||
|
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
|
||||||
|
|
||||||
|
if len(matching_files) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No file found in {repo_id} that match {filename}\n\n"
|
||||||
|
f"Available Files:\n{json.dumps(file_list)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(matching_files) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Multiple files found in {repo_id} matching {filename}\n\n"
|
||||||
|
f"Available Files:\n{json.dumps(files)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
(matching_file,) = matching_files
|
||||||
|
|
||||||
|
subfolder = str(Path(matching_file).parent)
|
||||||
|
filename = Path(matching_file).name
|
||||||
|
|
||||||
|
local_dir = "."
|
||||||
|
|
||||||
|
# download the file
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
local_dir=local_dir,
|
||||||
|
filename=filename,
|
||||||
|
subfolder=subfolder,
|
||||||
|
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_path = os.path.join(local_dir, filename)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
model_path=model_path,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LlamaState:
|
class LlamaState:
|
||||||
|
|
Loading…
Reference in a new issue