feat: Pull models directly from huggingface (#1206)

* Add from_pretrained method to Llama class

* Update docs

* Merge filename and pattern
This commit is contained in:
Andrei 2024-02-21 16:25:10 -05:00 committed by GitHub
parent e42f62c247
commit 0f8aa4ab5c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 136 additions and 27 deletions

View file

@ -212,6 +212,19 @@ Below is a short example demonstrating how to use the high-level API to for basi
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
## Pulling models from Hugging Face
You can pull `Llama` models from Hugging Face using the `from_pretrained` method.
You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
```python
llama = Llama.from_pretrained(
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
filename="*q8_0.gguf",
verbose=False
)
```
### Chat Completion
The high-level API also provides a simple interface for chat completion.

View file

@ -26,6 +26,7 @@ High-level Python bindings for llama.cpp.
- load_state
- token_bos
- token_eos
- from_pretrained
show_root_heading: true
::: llama_cpp.LlamaGrammar

View file

@ -4,6 +4,8 @@ import os
import sys
import uuid
import time
import json
import fnmatch
import multiprocessing
from typing import (
List,
@ -16,6 +18,7 @@ from typing import (
Callable,
)
from collections import deque
from pathlib import Path
import ctypes
@ -29,10 +32,7 @@ from .llama_cache import (
LlamaDiskCache, # type: ignore
LlamaRAMCache, # type: ignore
)
from .llama_tokenizer import (
BaseLlamaTokenizer,
LlamaTokenizer
)
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format
@ -50,9 +50,7 @@ from ._internals import (
_LlamaSamplingContext, # type: ignore
)
from ._logger import set_verbose
from ._utils import (
suppress_stdout_stderr
)
from ._utils import suppress_stdout_stderr
class Llama:
@ -189,7 +187,11 @@ class Llama:
Llama.__backend_initialized = True
if isinstance(numa, bool):
self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
self.numa = (
llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
if numa
else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
)
else:
self.numa = numa
@ -246,9 +248,9 @@ class Llama:
else:
raise ValueError(f"Unknown value type for {k}: {v}")
self._kv_overrides_array[
-1
].key = b"\0" # ensure sentinel element is zeroed
self._kv_overrides_array[-1].key = (
b"\0" # ensure sentinel element is zeroed
)
self.model_params.kv_overrides = self._kv_overrides_array
self.n_batch = min(n_ctx, n_batch) # ???
@ -256,7 +258,7 @@ class Llama:
self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1
)
# Context Params
self.context_params = llama_cpp.llama_context_default_params()
self.context_params.seed = seed
@ -289,7 +291,9 @@ class Llama:
)
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.mul_mat_q = mul_mat_q
self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
self.context_params.logits_all = (
logits_all if draft_model is None else True
) # Must be set to True for speculative decoding
self.context_params.embedding = embedding
self.context_params.offload_kqv = offload_kqv
@ -379,8 +383,14 @@ class Llama:
if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
if (
self.chat_format is None
and self.chat_handler is None
and "tokenizer.chat_template" in self.metadata
):
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
self.metadata
)
if chat_format is not None:
self.chat_format = chat_format
@ -406,9 +416,7 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template,
eos_token=eos_token,
bos_token=bos_token
template=template, eos_token=eos_token, bos_token=bos_token
).to_chat_handler()
if self.chat_format is None and self.chat_handler is None:
@ -459,7 +467,9 @@ class Llama:
"""
return self.tokenizer_.tokenize(text, add_bos, special)
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
"""Detokenize a list of tokens.
Args:
@ -565,7 +575,7 @@ class Llama:
logits[:] = (
logits_processor(self._input_ids, logits)
if idx is None
else logits_processor(self._input_ids[:idx + 1], logits)
else logits_processor(self._input_ids[: idx + 1], logits)
)
sampling_params = _LlamaSamplingParams(
@ -707,7 +717,9 @@ class Llama:
if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
draft_tokens = self.draft_model(
self.input_ids[: self.n_tokens + len(tokens)]
)
tokens.extend(
draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens)
@ -792,6 +804,7 @@ class Llama:
# decode and fetch embeddings
data: List[List[float]] = []
def decode_batch(n_seq: int):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@ -800,9 +813,9 @@ class Llama:
# store embeddings
for i in range(n_seq):
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
:n_embd
]
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
self._ctx.ctx, i
)[:n_embd]
if normalize:
norm = float(np.linalg.norm(embedding))
embedding = [v / norm for v in embedding]
@ -1669,12 +1682,13 @@ class Llama:
"""
try:
from openai.types.chat import ChatCompletion, ChatCompletionChunk
stream = kwargs.get("stream", False) # type: ignore
stream = kwargs.get("stream", False) # type: ignore
assert isinstance(stream, bool)
if stream:
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
else:
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
except ImportError:
raise ImportError(
"To use create_chat_completion_openai_v1, you must install the openai package."
@ -1866,7 +1880,88 @@ class Llama:
break
return longest_prefix
@classmethod
def from_pretrained(
cls,
repo_id: str,
filename: Optional[str],
local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
**kwargs: Any,
) -> "Llama":
"""Create a Llama model from a pretrained model name or path.
This method requires the huggingface-hub package.
You can install it with `pip install huggingface-hub`.
Args:
repo_id: The model repo id.
filename: A filename or glob pattern to match the model file in the repo.
local_dir: The local directory to save the model to.
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
**kwargs: Additional keyword arguments to pass to the Llama constructor.
Returns:
A Llama model."""
try:
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub.utils import validate_repo_id
except ImportError:
raise ImportError(
"Llama.from_pretrained requires the huggingface-hub package. "
"You can install it with `pip install huggingface-hub`."
)
validate_repo_id(repo_id)
hffs = HfFileSystem()
files = [
file["name"] if isinstance(file, dict) else file
for file in hffs.ls(repo_id)
]
# split each file into repo_id, subfolder, filename
file_list: List[str] = []
for file in files:
rel_path = Path(file).relative_to(repo_id)
file_list.append(str(rel_path))
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
if len(matching_files) == 0:
raise ValueError(
f"No file found in {repo_id} that match {filename}\n\n"
f"Available Files:\n{json.dumps(file_list)}"
)
if len(matching_files) > 1:
raise ValueError(
f"Multiple files found in {repo_id} matching {filename}\n\n"
f"Available Files:\n{json.dumps(files)}"
)
(matching_file,) = matching_files
subfolder = str(Path(matching_file).parent)
filename = Path(matching_file).name
local_dir = "."
# download the file
hf_hub_download(
repo_id=repo_id,
local_dir=local_dir,
filename=filename,
subfolder=subfolder,
local_dir_use_symlinks=local_dir_use_symlinks,
)
model_path = os.path.join(local_dir, filename)
return cls(
model_path=model_path,
**kwargs,
)
class LlamaState: