feat: Pull models directly from huggingface ()

* Add from_pretrained method to Llama class

* Update docs

* Merge filename and pattern
This commit is contained in:
Andrei 2024-02-21 16:25:10 -05:00 committed by GitHub
parent e42f62c247
commit 0f8aa4ab5c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 136 additions and 27 deletions

View file

@ -212,6 +212,19 @@ Below is a short example demonstrating how to use the high-level API to for basi
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class. Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
## Pulling models from Hugging Face
You can pull `Llama` models from Hugging Face using the `from_pretrained` method.
You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
```python
llama = Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q8_0.gguf",
    verbose=False
)
```
### Chat Completion ### Chat Completion
The high-level API also provides a simple interface for chat completion. The high-level API also provides a simple interface for chat completion.

View file

@ -26,6 +26,7 @@ High-level Python bindings for llama.cpp.
- load_state - load_state
- token_bos - token_bos
- token_eos - token_eos
- from_pretrained
show_root_heading: true show_root_heading: true
::: llama_cpp.LlamaGrammar ::: llama_cpp.LlamaGrammar

View file

@ -4,6 +4,8 @@ import os
import sys import sys
import uuid import uuid
import time import time
import json
import fnmatch
import multiprocessing import multiprocessing
from typing import ( from typing import (
List, List,
@ -16,6 +18,7 @@ from typing import (
Callable, Callable,
) )
from collections import deque from collections import deque
from pathlib import Path
import ctypes import ctypes
@ -29,10 +32,7 @@ from .llama_cache import (
LlamaDiskCache, # type: ignore LlamaDiskCache, # type: ignore
LlamaRAMCache, # type: ignore LlamaRAMCache, # type: ignore
) )
from .llama_tokenizer import ( from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
BaseLlamaTokenizer,
LlamaTokenizer
)
import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format import llama_cpp.llama_chat_format as llama_chat_format
@ -50,9 +50,7 @@ from ._internals import (
_LlamaSamplingContext, # type: ignore _LlamaSamplingContext, # type: ignore
) )
from ._logger import set_verbose from ._logger import set_verbose
from ._utils import ( from ._utils import suppress_stdout_stderr
suppress_stdout_stderr
)
class Llama: class Llama:
@ -189,7 +187,11 @@ class Llama:
Llama.__backend_initialized = True Llama.__backend_initialized = True
if isinstance(numa, bool): if isinstance(numa, bool):
self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED self.numa = (
llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
if numa
else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
)
else: else:
self.numa = numa self.numa = numa
@ -246,9 +248,9 @@ class Llama:
else: else:
raise ValueError(f"Unknown value type for {k}: {v}") raise ValueError(f"Unknown value type for {k}: {v}")
self._kv_overrides_array[ self._kv_overrides_array[-1].key = (
-1 b"\0" # ensure sentinel element is zeroed
].key = b"\0" # ensure sentinel element is zeroed )
self.model_params.kv_overrides = self._kv_overrides_array self.model_params.kv_overrides = self._kv_overrides_array
self.n_batch = min(n_ctx, n_batch) # ??? self.n_batch = min(n_ctx, n_batch) # ???
@ -256,7 +258,7 @@ class Llama:
self.n_threads_batch = n_threads_batch or max( self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1 multiprocessing.cpu_count() // 2, 1
) )
# Context Params # Context Params
self.context_params = llama_cpp.llama_context_default_params() self.context_params = llama_cpp.llama_context_default_params()
self.context_params.seed = seed self.context_params.seed = seed
@ -289,7 +291,9 @@ class Llama:
) )
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.mul_mat_q = mul_mat_q self.context_params.mul_mat_q = mul_mat_q
self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding self.context_params.logits_all = (
logits_all if draft_model is None else True
) # Must be set to True for speculative decoding
self.context_params.embedding = embedding self.context_params.embedding = embedding
self.context_params.offload_kqv = offload_kqv self.context_params.offload_kqv = offload_kqv
@ -379,8 +383,14 @@ class Llama:
if self.verbose: if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr) print(f"Model metadata: {self.metadata}", file=sys.stderr)
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata: if (
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata) self.chat_format is None
and self.chat_handler is None
and "tokenizer.chat_template" in self.metadata
):
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
self.metadata
)
if chat_format is not None: if chat_format is not None:
self.chat_format = chat_format self.chat_format = chat_format
@ -406,9 +416,7 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr) print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter( self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template, template=template, eos_token=eos_token, bos_token=bos_token
eos_token=eos_token,
bos_token=bos_token
).to_chat_handler() ).to_chat_handler()
if self.chat_format is None and self.chat_handler is None: if self.chat_format is None and self.chat_handler is None:
@ -459,7 +467,9 @@ class Llama:
""" """
return self.tokenizer_.tokenize(text, add_bos, special) return self.tokenizer_.tokenize(text, add_bos, special)
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes: def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
"""Detokenize a list of tokens. """Detokenize a list of tokens.
Args: Args:
@ -565,7 +575,7 @@ class Llama:
logits[:] = ( logits[:] = (
logits_processor(self._input_ids, logits) logits_processor(self._input_ids, logits)
if idx is None if idx is None
else logits_processor(self._input_ids[:idx + 1], logits) else logits_processor(self._input_ids[: idx + 1], logits)
) )
sampling_params = _LlamaSamplingParams( sampling_params = _LlamaSamplingParams(
@ -707,7 +717,9 @@ class Llama:
if self.draft_model is not None: if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)]) draft_tokens = self.draft_model(
self.input_ids[: self.n_tokens + len(tokens)]
)
tokens.extend( tokens.extend(
draft_tokens.astype(int)[ draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens) : self._n_ctx - self.n_tokens - len(tokens)
@ -792,6 +804,7 @@ class Llama:
# decode and fetch embeddings # decode and fetch embeddings
data: List[List[float]] = [] data: List[List[float]] = []
def decode_batch(n_seq: int): def decode_batch(n_seq: int):
assert self._ctx.ctx is not None assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx) llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@ -800,9 +813,9 @@ class Llama:
# store embeddings # store embeddings
for i in range(n_seq): for i in range(n_seq):
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[ embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
:n_embd self._ctx.ctx, i
] )[:n_embd]
if normalize: if normalize:
norm = float(np.linalg.norm(embedding)) norm = float(np.linalg.norm(embedding))
embedding = [v / norm for v in embedding] embedding = [v / norm for v in embedding]
@ -1669,12 +1682,13 @@ class Llama:
""" """
try: try:
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
stream = kwargs.get("stream", False) # type: ignore
stream = kwargs.get("stream", False) # type: ignore
assert isinstance(stream, bool) assert isinstance(stream, bool)
if stream: if stream:
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
else: else:
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"To use create_chat_completion_openai_v1, you must install the openai package." "To use create_chat_completion_openai_v1, you must install the openai package."
@ -1866,7 +1880,88 @@ class Llama:
break break
return longest_prefix return longest_prefix
@classmethod
def from_pretrained(
    cls,
    repo_id: str,
    filename: Optional[str],
    local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    **kwargs: Any,
) -> "Llama":
    """Create a Llama model from a file hosted in a Hugging Face repo.

    This method requires the huggingface-hub package.
    You can install it with `pip install huggingface-hub`.

    Args:
        repo_id: The model repo id.
        filename: A filename or glob pattern to match the model file in the repo.
            Exactly one file in the repo listing must match.
        local_dir: The local directory to save the model to.
        local_dir_use_symlinks: Whether to use symlinks when downloading the model.
        **kwargs: Additional keyword arguments to pass to the Llama constructor.

    Returns:
        A Llama model loaded from the downloaded file.

    Raises:
        ImportError: If the huggingface-hub package is not installed.
        ValueError: If `filename` is None, or if zero or more than one file
            in the repo listing matches `filename`.
    """
    try:
        from huggingface_hub import hf_hub_download, HfFileSystem
        from huggingface_hub.utils import validate_repo_id
    except ImportError as err:
        raise ImportError(
            "Llama.from_pretrained requires the huggingface-hub package. "
            "You can install it with `pip install huggingface-hub`."
        ) from err

    # Fail with a clear message instead of a TypeError from inside fnmatch.
    if filename is None:
        raise ValueError(
            "Llama.from_pretrained requires a filename or glob pattern, got None"
        )

    validate_repo_id(repo_id)

    hffs = HfFileSystem()

    # Listing entries are handled whether they come back as info dicts
    # (with a "name" key) or as plain path strings.
    files = [
        file["name"] if isinstance(file, dict) else file
        for file in hffs.ls(repo_id)
    ]

    # Strip the leading repo id so patterns are matched against repo-relative paths.
    file_list: List[str] = [
        str(Path(file).relative_to(repo_id)) for file in files
    ]

    matching_files = [
        file for file in file_list if fnmatch.fnmatch(file, filename)
    ]

    if len(matching_files) == 0:
        raise ValueError(
            f"No file found in {repo_id} that match {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    if len(matching_files) > 1:
        # Report the same repo-relative names that the pattern was matched against.
        raise ValueError(
            f"Multiple files found in {repo_id} matching {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    (matching_file,) = matching_files

    subfolder = str(Path(matching_file).parent)
    filename = Path(matching_file).name

    # Download the file into the caller-supplied local_dir (previously this was
    # overwritten with "."), and use the path reported by hf_hub_download so the
    # result is correct for subfolders and when local_dir is None.
    model_path = hf_hub_download(
        repo_id=repo_id,
        local_dir=local_dir,
        filename=filename,
        subfolder=subfolder,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )

    return cls(
        model_path=model_path,
        **kwargs,
    )
class LlamaState: class LlamaState: