diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..f0ef5f7 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,51 @@ +# Define the image argument and provide a default value +ARG IMAGE=python:3-slim-bullseye + +# Use the image as specified +FROM ${IMAGE} + +# Re-declare the ARG after FROM +ARG IMAGE + +# Update and upgrade the existing packages +RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + ninja-build \ + build-essential + +RUN python3 -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette + +# Perform the conditional installations based on the image +RUN echo "Image: ${IMAGE}" && \ + if [ "${IMAGE}" = "python:3-slim-bullseye" ] ; then \ + echo "OpenBLAS install:" && \ + apt-get install -y --no-install-recommends libopenblas-dev && \ + LLAMA_OPENBLAS=1 pip install llama-cpp-python --verbose; \ +else \ + echo "CuBLAS install:" && \ + LLAMA_CUBLAS=1 pip install llama-cpp-python --verbose; \ +fi + +# Clean up apt cache +RUN rm -rf /var/lib/apt/lists/* + +# Set a working directory for better clarity +WORKDIR /app + +# Copy files to the app directory +RUN echo "Installing model...this can take some time..." +COPY ./model.bin /app/model.bin +COPY ./start_server.sh /app/start_server.sh + +# Make the server start script executable +RUN chmod +x /app/start_server.sh + +# Set environment variable for the host +ENV HOST=0.0.0.0 + +# Expose a port for the server +EXPOSE 8000 + +# Run the server start script +CMD ["/bin/sh", "/app/start_server.sh"] diff --git a/Dockerfile.cuda b/docker/Dockerfile.cuda_simple similarity index 100% rename from Dockerfile.cuda rename to docker/Dockerfile.cuda_simple diff --git a/Dockerfile b/docker/Dockerfile.openblas_simple similarity index 100% rename from Dockerfile rename to docker/Dockerfile.openblas_simple diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000..100bcbd --- /dev/null +++ b/docker/README.md @@ -0,0 +1,46 @@ +# Dockerfiles for building the llama-cpp-python server +- `Dockerfile.openblas_simple` - a simple Dockerfile for non-GPU OpenBLAS +- `Dockerfile.cuda_simple` - a simple Dockerfile for CUDA accelerated CuBLAS +- `hug_model.py` - a Python utility for interactively choosing and downloading the latest `5_1` quantized models from [huggingface.co/TheBloke]( https://huggingface.co/TheBloke) +- `Dockerfile` - a single OpenBLAS and CuBLAS combined Dockerfile that automatically installs a previously downloaded model `model.bin` + +# Get model from Hugging Face +`python3 ./hug_model.py` + +You should now have a model in the current directory and `model.bin` symlinked to it for the subsequent Docker build and copy step. e.g. +``` +docker $ ls -lh *.bin +-rw-rw-r-- 1 user user 4.8G May 23 18:30 .q5_1.bin +lrwxrwxrwx 1 user user 24 May 23 18:30 model.bin -> .q5_1.bin +``` +**Note #1:** Make sure you have enough disk space to download the model. As the model is then copied into the image you will need at least +**TWICE** as much disk space as the size of the model: + +| Model | Quantized size | +|------:|----------------:| +| 7B | 5 GB | +| 13B | 10 GB | +| 30B | 25 GB | +| 65B | 50 GB | + +**Note #2:** If you want to pass or tune additional parameters, customise `./start_server.sh` before running `docker build ...` + +# Install Docker Server + +**Note #3:** This was tested with Docker running on Linux. If you can get it working on Windows or MacOS, please update this `README.md` with a PR! + +[Install Docker Engine](https://docs.docker.com/engine/install) + +# Use OpenBLAS +Use if you don't have a NVidia GPU. Defaults to `python:3-slim-bullseye` Docker base image and OpenBLAS: +## Build: +`docker build --build-arg -t openblas .` +## Run: +`docker run --cap-add SYS_RESOURCE -t openblas` + +# Use CuBLAS +Requires a NVidia GPU with sufficient VRAM (approximately as much as the size above) and Docker NVidia support (see [container-toolkit/install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)) +## Build: +`docker build --build-arg IMAGE=nvidia/cuda:12.1.1-devel-ubuntu22.04 -t cublas .` +## Run: +`docker run --cap-add SYS_RESOURCE -t cublas` diff --git a/docker/hug_model.py b/docker/hug_model.py new file mode 100644 index 0000000..848a1aa --- /dev/null +++ b/docker/hug_model.py @@ -0,0 +1,116 @@ +import requests +import json +import os +import struct + +def make_request(url, params=None): + print(f"Making request to {url}...") + response = requests.get(url, params=params) + if response.status_code == 200: + return json.loads(response.text) + else: + print(f"Request failed with status code {response.status_code}") + return None + +def check_magic_and_version(filename): + with open(filename, 'rb') as f: + # Read the first 6 bytes from the file + data = f.read(6) + + # Unpack the binary data, interpreting the first 4 bytes as a little-endian unsigned int + # and the next 2 bytes as a little-endian unsigned short + magic, version = struct.unpack('= 10485760: # 10 MB + print('.', end='', flush=True) + total_downloaded = 0 + print("\nDownload complete.") + + # Creating a symbolic link from destination to "model.bin" + if os.path.isfile("model.bin"): + os.remove("model.bin") # remove the existing link if any + os.symlink(destination, "model.bin") + else: + print(f"Download failed with status code {response.status_code}") + +def get_user_choice(model_list): + # Print the enumerated list + print("\n") + for i, (model_id, rfilename) in enumerate(model_list): + print(f"{i+1}: Model ID: {model_id}, RFilename: {rfilename}") + + # Get user's choice + choice = input("Choose a model to download by entering the corresponding number: ") + try: + index = int(choice) - 1 + if 0 <= index < len(model_list): + # Return the chosen model + return model_list[index] + else: + print("Invalid choice.") + except ValueError: + print("Invalid input. Please enter a number corresponding to a model.") + except IndexError: + print("Invalid choice. Index out of range.") + + return None + +import argparse + +def main(): + # Create an argument parser + parser = argparse.ArgumentParser(description='Process the model version.') + parser.add_argument('-v', '--version', type=int, default=0x0003, + help='an integer for the version to be used') + + # Parse the arguments + args = parser.parse_args() + + # Define the parameters + params = { + "author": "TheBloke", # Filter by author + "tags": "llama" + } + + models = make_request('https://huggingface.co/api/models', params=params) + if models is None: + return + + model_list = [] + # Iterate over the models + for model in models: + model_id = model['id'] + model_info = make_request(f'https://huggingface.co/api/models/{model_id}') + if model_info is None: + continue + + for sibling in model_info.get('siblings', []): + rfilename = sibling.get('rfilename') + if rfilename and 'q5_1' in rfilename: + model_list.append((model_id, rfilename)) + + model_choice = get_user_choice(model_list) + if model_choice is not None: + model_id, rfilename = model_choice + url = f"https://huggingface.co/{model_id}/resolve/main/{rfilename}" + download_file(url, rfilename) + _, version = check_magic_and_version(rfilename) + if version != args.version: + print(f"Warning: Expected version {args.version}, but found different version in the file.") + +if __name__ == '__main__': + main() diff --git a/docker/start_server.sh b/docker/start_server.sh new file mode 100755 index 0000000..176bd87 --- /dev/null +++ b/docker/start_server.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +# For mmap support +ulimit -l unlimited + +if [ "$IMAGE" = "python:3-slim-bullseye" ]; then + python3 -B -m llama_cpp.server --model /app/model.bin +else + # You may have to reduce --n_gpu_layers=1000 to 20 or less if you don't have enough VRAM + python3 -B -m llama_cpp.server --model /app/model.bin --n_gpu_layers=1000 +fi diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 43fa9c7..b43dfe7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,7 +4,17 @@ import uuid import time import math import multiprocessing -from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple +from typing import ( + List, + Optional, + Union, + Generator, + Sequence, + Iterator, + Deque, + Tuple, + Callable, +) from collections import deque, OrderedDict from . import llama_cpp @@ -72,6 +82,24 @@ class LlamaState: self.llama_state_size = llama_state_size +LogitsProcessor = Callable[[List[int], List[float]], List[float]] + + +class LogitsProcessorList(List[LogitsProcessor]): + def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +StoppingCriteria = Callable[[List[int], List[float]], bool] + + +class StoppingCriteriaList(List[StoppingCriteria]): + def __call__(self, input_ids: List[int], logits: List[float]) -> bool: + return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) + + class Llama: """High-level Python wrapper for a llama.cpp model.""" @@ -314,6 +342,7 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -326,6 +355,10 @@ class Llama: else last_n_tokens_size ) logits = self.eval_logits[-1] + + if logits_processor is not None: + logits = logits_processor(list(self.eval_tokens), logits) + nl_logit = logits[self._token_nl] candidates = self._candidates for i, logit in enumerate(logits): @@ -434,6 +467,7 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, ): """Sample a token from the model. @@ -466,6 +500,7 @@ class Llama: mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, + logits_processor=logits_processor, ) def generate( @@ -482,6 +517,8 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -539,7 +576,12 @@ class Llama: mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + logits_processor=logits_processor, ) + if stopping_criteria is not None and stopping_criteria( + list(self.eval_tokens), self.eval_logits[-1] + ): + return tokens_or_none = yield token tokens = [token] if tokens_or_none is not None: @@ -637,6 +679,7 @@ class Llama: model: Optional[str] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None + completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) completion_tokens: List[int] = [] @@ -1334,6 +1377,11 @@ class Llama: assert self.ctx is not None return llama_cpp.llama_n_vocab(self.ctx) + def tokenizer(self) -> "LlamaTokenizer": + """Return the tokenizer for this model.""" + assert self.ctx is not None + return LlamaTokenizer(self) + @staticmethod def token_eos() -> int: """Return the end-of-sequence token.""" @@ -1364,3 +1412,18 @@ class Llama: else: break return longest_prefix + + +class LlamaTokenizer: + def __init__(self, llama: Llama): + self.llama = llama + + def encode(self, text: str) -> List[int]: + return self.llama.tokenize(text.encode("utf-8", errors="ignore")) + + def decode(self, tokens: List[int]) -> str: + return self.llama.detokenize(tokens).decode("utf-8", errors="ignore") + + @classmethod + def from_ggml_file(cls, path: str) -> "LlamaTokenizer": + return cls(Llama(model_path=path, vocab_only=True))