feat: Pull models directly from huggingface (#1206)
* Add from_pretrained method to Llama class * Update docs * Merge filename and pattern
This commit is contained in:
parent
e42f62c247
commit
0f8aa4ab5c
3 changed files with 136 additions and 27 deletions
13
README.md
13
README.md
|
@ -212,6 +212,19 @@ Below is a short example demonstrating how to use the high-level API to for basi
|
|||
|
||||
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
|
||||
|
||||
## Pulling models from Hugging Face
|
||||
|
||||
You can pull `Llama` models from Hugging Face using the `from_pretrained` method.
|
||||
You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
|
||||
|
||||
```python
|
||||
llama = Llama.from_pretrained(
|
||||
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
|
||||
filename="*q8_0.gguf",
|
||||
verbose=False
|
||||
)
|
||||
```
|
||||
|
||||
### Chat Completion
|
||||
|
||||
The high-level API also provides a simple interface for chat completion.
|
||||
|
|
|
@ -26,6 +26,7 @@ High-level Python bindings for llama.cpp.
|
|||
- load_state
|
||||
- token_bos
|
||||
- token_eos
|
||||
- from_pretrained
|
||||
show_root_heading: true
|
||||
|
||||
::: llama_cpp.LlamaGrammar
|
||||
|
|
|
@ -4,6 +4,8 @@ import os
|
|||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
import fnmatch
|
||||
import multiprocessing
|
||||
from typing import (
|
||||
List,
|
||||
|
@ -16,6 +18,7 @@ from typing import (
|
|||
Callable,
|
||||
)
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import ctypes
|
||||
|
||||
|
@ -29,10 +32,7 @@ from .llama_cache import (
|
|||
LlamaDiskCache, # type: ignore
|
||||
LlamaRAMCache, # type: ignore
|
||||
)
|
||||
from .llama_tokenizer import (
|
||||
BaseLlamaTokenizer,
|
||||
LlamaTokenizer
|
||||
)
|
||||
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
|
||||
import llama_cpp.llama_cpp as llama_cpp
|
||||
import llama_cpp.llama_chat_format as llama_chat_format
|
||||
|
||||
|
@ -50,9 +50,7 @@ from ._internals import (
|
|||
_LlamaSamplingContext, # type: ignore
|
||||
)
|
||||
from ._logger import set_verbose
|
||||
from ._utils import (
|
||||
suppress_stdout_stderr
|
||||
)
|
||||
from ._utils import suppress_stdout_stderr
|
||||
|
||||
|
||||
class Llama:
|
||||
|
@ -189,7 +187,11 @@ class Llama:
|
|||
Llama.__backend_initialized = True
|
||||
|
||||
if isinstance(numa, bool):
|
||||
self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
|
||||
self.numa = (
|
||||
llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
|
||||
if numa
|
||||
else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
|
||||
)
|
||||
else:
|
||||
self.numa = numa
|
||||
|
||||
|
@ -246,9 +248,9 @@ class Llama:
|
|||
else:
|
||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||
|
||||
self._kv_overrides_array[
|
||||
-1
|
||||
].key = b"\0" # ensure sentinel element is zeroed
|
||||
self._kv_overrides_array[-1].key = (
|
||||
b"\0" # ensure sentinel element is zeroed
|
||||
)
|
||||
self.model_params.kv_overrides = self._kv_overrides_array
|
||||
|
||||
self.n_batch = min(n_ctx, n_batch) # ???
|
||||
|
@ -256,7 +258,7 @@ class Llama:
|
|||
self.n_threads_batch = n_threads_batch or max(
|
||||
multiprocessing.cpu_count() // 2, 1
|
||||
)
|
||||
|
||||
|
||||
# Context Params
|
||||
self.context_params = llama_cpp.llama_context_default_params()
|
||||
self.context_params.seed = seed
|
||||
|
@ -289,7 +291,9 @@ class Llama:
|
|||
)
|
||||
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
|
||||
self.context_params.mul_mat_q = mul_mat_q
|
||||
self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
|
||||
self.context_params.logits_all = (
|
||||
logits_all if draft_model is None else True
|
||||
) # Must be set to True for speculative decoding
|
||||
self.context_params.embedding = embedding
|
||||
self.context_params.offload_kqv = offload_kqv
|
||||
|
||||
|
@ -379,8 +383,14 @@ class Llama:
|
|||
if self.verbose:
|
||||
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
||||
|
||||
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
|
||||
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
|
||||
if (
|
||||
self.chat_format is None
|
||||
and self.chat_handler is None
|
||||
and "tokenizer.chat_template" in self.metadata
|
||||
):
|
||||
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
|
||||
self.metadata
|
||||
)
|
||||
|
||||
if chat_format is not None:
|
||||
self.chat_format = chat_format
|
||||
|
@ -406,9 +416,7 @@ class Llama:
|
|||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||
|
||||
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
||||
template=template,
|
||||
eos_token=eos_token,
|
||||
bos_token=bos_token
|
||||
template=template, eos_token=eos_token, bos_token=bos_token
|
||||
).to_chat_handler()
|
||||
|
||||
if self.chat_format is None and self.chat_handler is None:
|
||||
|
@ -459,7 +467,9 @@ class Llama:
|
|||
"""
|
||||
return self.tokenizer_.tokenize(text, add_bos, special)
|
||||
|
||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
||||
def detokenize(
|
||||
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
|
||||
) -> bytes:
|
||||
"""Detokenize a list of tokens.
|
||||
|
||||
Args:
|
||||
|
@ -565,7 +575,7 @@ class Llama:
|
|||
logits[:] = (
|
||||
logits_processor(self._input_ids, logits)
|
||||
if idx is None
|
||||
else logits_processor(self._input_ids[:idx + 1], logits)
|
||||
else logits_processor(self._input_ids[: idx + 1], logits)
|
||||
)
|
||||
|
||||
sampling_params = _LlamaSamplingParams(
|
||||
|
@ -707,7 +717,9 @@ class Llama:
|
|||
|
||||
if self.draft_model is not None:
|
||||
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
|
||||
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
|
||||
draft_tokens = self.draft_model(
|
||||
self.input_ids[: self.n_tokens + len(tokens)]
|
||||
)
|
||||
tokens.extend(
|
||||
draft_tokens.astype(int)[
|
||||
: self._n_ctx - self.n_tokens - len(tokens)
|
||||
|
@ -792,6 +804,7 @@ class Llama:
|
|||
|
||||
# decode and fetch embeddings
|
||||
data: List[List[float]] = []
|
||||
|
||||
def decode_batch(n_seq: int):
|
||||
assert self._ctx.ctx is not None
|
||||
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
||||
|
@ -800,9 +813,9 @@ class Llama:
|
|||
|
||||
# store embeddings
|
||||
for i in range(n_seq):
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
|
||||
:n_embd
|
||||
]
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
|
||||
self._ctx.ctx, i
|
||||
)[:n_embd]
|
||||
if normalize:
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
embedding = [v / norm for v in embedding]
|
||||
|
@ -1669,12 +1682,13 @@ class Llama:
|
|||
"""
|
||||
try:
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
stream = kwargs.get("stream", False) # type: ignore
|
||||
|
||||
stream = kwargs.get("stream", False) # type: ignore
|
||||
assert isinstance(stream, bool)
|
||||
if stream:
|
||||
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
|
||||
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
|
||||
else:
|
||||
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
|
||||
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use create_chat_completion_openai_v1, you must install the openai package."
|
||||
|
@ -1866,7 +1880,88 @@ class Llama:
|
|||
break
|
||||
return longest_prefix
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
repo_id: str,
|
||||
filename: Optional[str],
|
||||
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
|
||||
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
||||
**kwargs: Any,
|
||||
) -> "Llama":
|
||||
"""Create a Llama model from a pretrained model name or path.
|
||||
This method requires the huggingface-hub package.
|
||||
You can install it with `pip install huggingface-hub`.
|
||||
|
||||
Args:
|
||||
repo_id: The model repo id.
|
||||
filename: A filename or glob pattern to match the model file in the repo.
|
||||
local_dir: The local directory to save the model to.
|
||||
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
|
||||
**kwargs: Additional keyword arguments to pass to the Llama constructor.
|
||||
|
||||
Returns:
|
||||
A Llama model."""
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download, HfFileSystem
|
||||
from huggingface_hub.utils import validate_repo_id
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Llama.from_pretrained requires the huggingface-hub package. "
|
||||
"You can install it with `pip install huggingface-hub`."
|
||||
)
|
||||
|
||||
validate_repo_id(repo_id)
|
||||
|
||||
hffs = HfFileSystem()
|
||||
|
||||
files = [
|
||||
file["name"] if isinstance(file, dict) else file
|
||||
for file in hffs.ls(repo_id)
|
||||
]
|
||||
|
||||
# split each file into repo_id, subfolder, filename
|
||||
file_list: List[str] = []
|
||||
for file in files:
|
||||
rel_path = Path(file).relative_to(repo_id)
|
||||
file_list.append(str(rel_path))
|
||||
|
||||
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
|
||||
|
||||
if len(matching_files) == 0:
|
||||
raise ValueError(
|
||||
f"No file found in {repo_id} that match {filename}\n\n"
|
||||
f"Available Files:\n{json.dumps(file_list)}"
|
||||
)
|
||||
|
||||
if len(matching_files) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple files found in {repo_id} matching {filename}\n\n"
|
||||
f"Available Files:\n{json.dumps(files)}"
|
||||
)
|
||||
|
||||
(matching_file,) = matching_files
|
||||
|
||||
subfolder = str(Path(matching_file).parent)
|
||||
filename = Path(matching_file).name
|
||||
|
||||
local_dir = "."
|
||||
|
||||
# download the file
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
local_dir=local_dir,
|
||||
filename=filename,
|
||||
subfolder=subfolder,
|
||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||
)
|
||||
|
||||
model_path = os.path.join(local_dir, filename)
|
||||
|
||||
return cls(
|
||||
model_path=model_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LlamaState:
|
||||
|
|
Loading…
Add table
Reference in a new issue