From 0f8aa4ab5c63fbb02630242bfe03d0349929b913 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 21 Feb 2024 16:25:10 -0500 Subject: [PATCH] feat: Pull models directly from huggingface (#1206) * Add from_pretrained method to Llama class * Update docs * Merge filename and pattern --- README.md | 13 ++++ docs/api-reference.md | 1 + llama_cpp/llama.py | 149 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 136 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 7da8e5f..03c95e8 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,19 @@ Below is a short example demonstrating how to use the high-level API to for basi Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class. +## Pulling models from Hugging Face + +You can pull `Llama` models from Hugging Face using the `from_pretrained` method. +You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`). + +```python +llama = Llama.from_pretrained( + repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF", + filename="*q8_0.gguf", + verbose=False +) +``` + ### Chat Completion The high-level API also provides a simple interface for chat completion. diff --git a/docs/api-reference.md b/docs/api-reference.md index 562410f..c635e11 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -26,6 +26,7 @@ High-level Python bindings for llama.cpp. - load_state - token_bos - token_eos + - from_pretrained show_root_heading: true ::: llama_cpp.LlamaGrammar diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 30cab0a..65902c8 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,6 +4,8 @@ import os import sys import uuid import time +import json +import fnmatch import multiprocessing from typing import ( List, @@ -16,6 +18,7 @@ from typing import ( Callable, ) from collections import deque +from pathlib import Path import ctypes @@ -29,10 +32,7 @@ from .llama_cache import ( LlamaDiskCache, # type: ignore LlamaRAMCache, # type: ignore ) -from .llama_tokenizer import ( - BaseLlamaTokenizer, - LlamaTokenizer -) +from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format @@ -50,9 +50,7 @@ from ._internals import ( _LlamaSamplingContext, # type: ignore ) from ._logger import set_verbose -from ._utils import ( - suppress_stdout_stderr -) +from ._utils import suppress_stdout_stderr class Llama: @@ -189,7 +187,11 @@ class Llama: Llama.__backend_initialized = True if isinstance(numa, bool): - self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED + self.numa = ( + llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE + if numa + else llama_cpp.GGML_NUMA_STRATEGY_DISABLED + ) else: self.numa = numa @@ -246,9 +248,9 @@ class Llama: else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array[ - -1 - ].key = b"\0" # ensure sentinel element is zeroed + self._kv_overrides_array[-1].key = ( + b"\0" # ensure sentinel element is zeroed + ) self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ??? 
@@ -256,7 +258,7 @@ class Llama: self.n_threads_batch = n_threads_batch or max( multiprocessing.cpu_count() // 2, 1 ) - + # Context Params self.context_params = llama_cpp.llama_context_default_params() self.context_params.seed = seed @@ -289,7 +291,9 @@ class Llama: ) self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 self.context_params.mul_mat_q = mul_mat_q - self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding + self.context_params.logits_all = ( + logits_all if draft_model is None else True + ) # Must be set to True for speculative decoding self.context_params.embedding = embedding self.context_params.offload_kqv = offload_kqv @@ -379,8 +383,14 @@ class Llama: if self.verbose: print(f"Model metadata: {self.metadata}", file=sys.stderr) - if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata: - chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata) + if ( + self.chat_format is None + and self.chat_handler is None + and "tokenizer.chat_template" in self.metadata + ): + chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata( + self.metadata + ) if chat_format is not None: self.chat_format = chat_format @@ -406,9 +416,7 @@ class Llama: print(f"Using chat bos_token: {bos_token}", file=sys.stderr) self.chat_handler = llama_chat_format.Jinja2ChatFormatter( - template=template, - eos_token=eos_token, - bos_token=bos_token + template=template, eos_token=eos_token, bos_token=bos_token ).to_chat_handler() if self.chat_format is None and self.chat_handler is None: @@ -459,7 +467,9 @@ class Llama: """ return self.tokenizer_.tokenize(text, add_bos, special) - def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes: + def detokenize( + self, tokens: List[int], prev_tokens: Optional[List[int]] = None + ) -> bytes: """Detokenize a list of tokens. 
Args: @@ -565,7 +575,7 @@ class Llama: logits[:] = ( logits_processor(self._input_ids, logits) if idx is None - else logits_processor(self._input_ids[:idx + 1], logits) + else logits_processor(self._input_ids[: idx + 1], logits) ) sampling_params = _LlamaSamplingParams( @@ -707,7 +717,9 @@ class Llama: if self.draft_model is not None: self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens - draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)]) + draft_tokens = self.draft_model( + self.input_ids[: self.n_tokens + len(tokens)] + ) tokens.extend( draft_tokens.astype(int)[ : self._n_ctx - self.n_tokens - len(tokens) @@ -792,6 +804,7 @@ class Llama: # decode and fetch embeddings data: List[List[float]] = [] + def decode_batch(n_seq: int): assert self._ctx.ctx is not None llama_cpp.llama_kv_cache_clear(self._ctx.ctx) @@ -800,9 +813,9 @@ class Llama: # store embeddings for i in range(n_seq): - embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[ - :n_embd - ] + embedding: List[float] = llama_cpp.llama_get_embeddings_ith( + self._ctx.ctx, i + )[:n_embd] if normalize: norm = float(np.linalg.norm(embedding)) embedding = [v / norm for v in embedding] @@ -1669,12 +1682,13 @@ class Llama: """ try: from openai.types.chat import ChatCompletion, ChatCompletionChunk - stream = kwargs.get("stream", False) # type: ignore + + stream = kwargs.get("stream", False) # type: ignore assert isinstance(stream, bool) if stream: - return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore + return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore else: - return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore + return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore except ImportError: raise ImportError( "To use create_chat_completion_openai_v1, you must install the openai package." @@ -1866,7 +1880,88 @@ class Llama: break return longest_prefix + @classmethod + def from_pretrained( + cls, + repo_id: str, + filename: Optional[str], + local_dir: Optional[Union[str, os.PathLike[str]]] = ".", + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + **kwargs: Any, + ) -> "Llama": + """Create a Llama model from a pretrained model name or path. + This method requires the huggingface-hub package. + You can install it with `pip install huggingface-hub`. + Args: + repo_id: The model repo id. + filename: A filename or glob pattern to match the model file in the repo. + local_dir: The local directory to save the model to. + local_dir_use_symlinks: Whether to use symlinks when downloading the model. + **kwargs: Additional keyword arguments to pass to the Llama constructor. + + Returns: + A Llama model.""" + try: + from huggingface_hub import hf_hub_download, HfFileSystem + from huggingface_hub.utils import validate_repo_id + except ImportError: + raise ImportError( + "Llama.from_pretrained requires the huggingface-hub package. " + "You can install it with `pip install huggingface-hub`." 
+ ) + + validate_repo_id(repo_id) + + hffs = HfFileSystem() + + files = [ + file["name"] if isinstance(file, dict) else file + for file in hffs.ls(repo_id) + ] + + # split each file into repo_id, subfolder, filename + file_list: List[str] = [] + for file in files: + rel_path = Path(file).relative_to(repo_id) + file_list.append(str(rel_path)) + + matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore + + if len(matching_files) == 0: + raise ValueError( + f"No file found in {repo_id} that match {filename}\n\n" + f"Available Files:\n{json.dumps(file_list)}" + ) + + if len(matching_files) > 1: + raise ValueError( + f"Multiple files found in {repo_id} matching {filename}\n\n" + f"Available Files:\n{json.dumps(files)}" + ) + + (matching_file,) = matching_files + + subfolder = str(Path(matching_file).parent) + filename = Path(matching_file).name + + local_dir = "." + + # download the file + hf_hub_download( + repo_id=repo_id, + local_dir=local_dir, + filename=filename, + subfolder=subfolder, + local_dir_use_symlinks=local_dir_use_symlinks, + ) + + model_path = os.path.join(local_dir, filename) + + return cls( + model_path=model_path, + **kwargs, + ) class LlamaState: