From eaff7a8678f655a41d978ba66cc0c805c8430684 Mon Sep 17 00:00:00 2001 From: Gary Mulder Date: Tue, 23 May 2023 19:26:40 +0000 Subject: [PATCH 01/11] Initial commit of auto docker --- docker/Dockerfile | 51 ++++++++++++++++++ docker/README.md | 33 ++++++++++++ docker/hug_model.py | 119 +++++++++++++++++++++++++++++++++++++++++ docker/start_server.sh | 11 ++++ 4 files changed, 214 insertions(+) create mode 100644 docker/Dockerfile create mode 100644 docker/README.md create mode 100644 docker/hug_model.py create mode 100755 docker/start_server.sh diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..f0ef5f7 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,51 @@ +# Define the image argument and provide a default value +ARG IMAGE=python:3-slim-bullseye + +# Use the image as specified +FROM ${IMAGE} + +# Re-declare the ARG after FROM +ARG IMAGE + +# Update and upgrade the existing packages +RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + ninja-build \ + build-essential + +RUN python3 -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette + +# Perform the conditional installations based on the image +RUN echo "Image: ${IMAGE}" && \ + if [ "${IMAGE}" = "python:3-slim-bullseye" ] ; then \ + echo "OpenBLAS install:" && \ + apt-get install -y --no-install-recommends libopenblas-dev && \ + LLAMA_OPENBLAS=1 pip install llama-cpp-python --verbose; \ +else \ + echo "CuBLAS install:" && \ + LLAMA_CUBLAS=1 pip install llama-cpp-python --verbose; \ +fi + +# Clean up apt cache +RUN rm -rf /var/lib/apt/lists/* + +# Set a working directory for better clarity +WORKDIR /app + +# Copy files to the app directory +RUN echo "Installing model...this can take some time..." +COPY ./model.bin /app/model.bin +COPY ./start_server.sh /app/start_server.sh + +# Make the server start script executable +RUN chmod +x /app/start_server.sh + +# Set environment variable for the host +ENV HOST=0.0.0.0 + +# Expose a port for the server +EXPOSE 8000 + +# Run the server start script +CMD ["/bin/sh", "/app/start_server.sh"] diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000..445f264 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,33 @@ +# Get model from Hugging Face +`python3 ./hug_model.py` + +You should now have a model in the current directory and model.bin symlinked to it for the subsequent Docker build and copy step. e.g. +``` +docker $ ls -lh *.bin +-rw-rw-r-- 1 user user 4.8G May 23 18:30 llama-7b.ggmlv3.q5_1.bin +lrwxrwxrwx 1 user user 24 May 23 18:30 model.bin -> .q5_1.bin +``` +- Note #1: Make sure you have enough disk space to d/l the model. 
As the model is then copied into the image you will need at least
+**TWICE** as much disk space as the size of the model:
+
+| Model | Quantized size |
+|------:|----------------:|
+| 7B | 5 GB |
+| 13B | 10 GB |
+| 30B | 25 GB |
+| 65B | 50 GB |
+
+- Note #2: If you want to pass or tune additional parameters, customise `./start_server.sh` before running `docker build ...`
+
+# Use OpenBLAS (No NVidia GPU, defaults to `python:3-slim-bullseye` Docker base image)
+## Build:
+`docker build --build-arg -t openblas .`
+## Run:
+`docker run --cap-add SYS_RESOURCE -t openblas`
+
+# Use CuBLAS
+Requires NVidia GPU and Docker NVidia support (see https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
+## Build:
+`docker build --build-arg IMAGE=nvidia/cuda:12.1.1-devel-ubuntu22.04 -t opencuda .`
+## Run:
+`docker run --cap-add SYS_RESOURCE -t cublas`
diff --git a/docker/hug_model.py b/docker/hug_model.py
new file mode 100644
index 0000000..476f53c
--- /dev/null
+++ b/docker/hug_model.py
@@ -0,0 +1,119 @@
+import requests
+import json
+import os
+import struct
+
+def make_request(url, params=None):
+    print(f"Making request to {url}...")
+    response = requests.get(url, params=params)
+    if response.status_code == 200:
+        return json.loads(response.text)
+    else:
+        print(f"Request failed with status code {response.status_code}")
+        return None
+
+def check_magic_and_version(filename):
+    with open(filename, 'rb') as f:
+        # Read the first 6 bytes from the file
+        data = f.read(6)
+
+    # Unpack the binary data, interpreting the first 4 bytes as a little-endian unsigned int
+    # and the next 2 bytes as a little-endian unsigned short
+    magic, version = struct.unpack('<I H', data)
+
+    print(f"magic: 0x{magic:08x}, version: 0x{version:04x}, file: {filename}")
+
+    return magic, version
+
+import os
+import requests
+
+def download_file(url, destination):
+    print(f"Downloading {url} to {destination}...")
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        with open(destination, 'wb') as f:
+            total_downloaded = 0
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk: # filter out keep-alive new chunks
+                    f.write(chunk)
+                    total_downloaded += len(chunk)
+                    if total_downloaded >= 10485760: # 10 MB
+                        print('.', end='', flush=True)
+                        total_downloaded = 0
+        print("\nDownload complete.")
+
+        # Creating a symbolic link from destination to "model.bin"
+        if os.path.isfile("model.bin"):
+            os.remove("model.bin") # remove the existing link if any
+        os.symlink(destination, "model.bin")
+    else:
+        print(f"Download failed with status code {response.status_code}")
+
+def get_user_choice(model_list):
+    # Print the enumerated list
+    print("\n")
+    for i, (model_id, rfilename) in enumerate(model_list):
+        print(f"{i+1}: Model ID: {model_id}, RFilename: {rfilename}")
+
+    # Get user's choice
+    choice = input("Choose a model to download by entering the corresponding number: ")
+    try:
+        index = int(choice) - 1
+        if 0 <= index < len(model_list):
+            # Return the chosen model
+            return model_list[index]
+        else:
+            print("Invalid choice.")
+    except ValueError:
+        print("Invalid input. Please enter a number corresponding to a model.")
+    except IndexError:
+        print("Invalid choice. 
Index out of range.") + + return None + +import argparse + +def main(): + # Create an argument parser + parser = argparse.ArgumentParser(description='Process the model version.') + parser.add_argument('-v', '--version', type=int, default=0x0003, + help='an integer for the version to be used') + + # Parse the arguments + args = parser.parse_args() + + # Define the parameters + params = { + "author": "TheBloke", # Filter by author + "tags": "llama" + } + + models = make_request('https://huggingface.co/api/models', params=params) + if models is None: + return + + model_list = [] + # Iterate over the models + for model in models: + model_id = model['id'] + model_info = make_request(f'https://huggingface.co/api/models/{model_id}') + if model_info is None: + continue + + for sibling in model_info.get('siblings', []): + rfilename = sibling.get('rfilename') + if rfilename and 'q5_1' in rfilename: + model_list.append((model_id, rfilename)) + + model_choice = get_user_choice(model_list) + if model_choice is not None: + model_id, rfilename = model_choice + url = f"https://huggingface.co/{model_id}/resolve/main/{rfilename}" + download_file(url, rfilename) + _, version = check_magic_and_version(rfilename) + if version != args.version: + print(f"Warning: Expected version {args.version}, but found different version in the file.") + +if __name__ == '__main__': + main() diff --git a/docker/start_server.sh b/docker/start_server.sh new file mode 100755 index 0000000..176bd87 --- /dev/null +++ b/docker/start_server.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +# For mmap support +ulimit -l unlimited + +if [ "$IMAGE" = "python:3-slim-bullseye" ]; then + python3 -B -m llama_cpp.server --model /app/model.bin +else + # You may have to reduce --n_gpu_layers=1000 to 20 or less if you don't have enough VRAM + python3 -B -m llama_cpp.server --model /app/model.bin --n_gpu_layers=1000 +fi From 70f629a72fe1dae576988a8107f683c66c887d7f Mon Sep 17 00:00:00 2001 From: Gary Mulder Date: Tue, 23 May 2023 20:36:21 +0100 Subject: [PATCH 02/11] Update README.md --- docker/README.md | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/docker/README.md b/docker/README.md index 445f264..3a538af 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,13 +1,13 @@ # Get model from Hugging Face `python3 ./hug_model.py` -You should now have a model in the current directory and model.bin symlinked to it for the subsequent Docker build and copy step. e.g. +You should now have a model in the current directory and `model.bin` symlinked to it for the subsequent Docker build and copy step. e.g. ``` docker $ ls -lh *.bin --rw-rw-r-- 1 user user 4.8G May 23 18:30 llama-7b.ggmlv3.q5_1.bin +-rw-rw-r-- 1 user user 4.8G May 23 18:30 .q5_1.bin lrwxrwxrwx 1 user user 24 May 23 18:30 model.bin -> .q5_1.bin ``` -- Note #1: Make sure you have enough disk space to d/l the model. As the model is then copied into the image you will need at least +**Note #1:** Make sure you have enough disk space to d/l the model. 
As the model is then copied into the image you will need at least **TWICE** as much disk space as the size of the model: | Model | Quantized size | @@ -17,16 +17,23 @@ lrwxrwxrwx 1 user user 24 May 23 18:30 model.bin -> .q5 | 30B | 25 GB | | 65B | 50 GB | -- Note #2: If you want to pass or tune additional parameters, customise `./start_server.sh` before running `docker build ...` +**Note #2:** If you want to pass or tune additional parameters, customise `./start_server.sh` before running `docker build ...` -# Use OpenBLAS (No NVidia GPU, defaults to `python:3-slim-bullseye` Docker base image) +# Install Docker Server + +**Note #3:** This was tested with Docker running on Linux. If you can get it working on Windows or MacOS, please update this README with a PR! + +[Install Docker Engine](https://docs.docker.com/engine/install) + +# Use OpenBLAS +No NVidia GPU, defaults to `python:3-slim-bullseye` Docker base image and OpenBlAS: ## Build: `docker build --build-arg -t openblas .` ## Run: `docker run --cap-add SYS_RESOURCE -t openblas` # Use CuBLAS -Requires NVidia GPU and Docker NVidia support (see https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) +Requires NVidia GPU and Docker NVidia support (see [container-toolkit/install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)) ## Build: `docker build --build-arg IMAGE=nvidia/cuda:12.1.1-devel-ubuntu22.04 -t opencuda .` ## Run: From ed19071ef8439d876bde415852cd53ba0a863ebd Mon Sep 17 00:00:00 2001 From: Gary Mulder Date: Tue, 23 May 2023 19:38:37 +0000 Subject: [PATCH 03/11] Renamed and moved old Dockerfiles --- Dockerfile.cuda => docker/Dockerfile.cuda_simple | 0 Dockerfile => docker/Dockerfile.openblas_simple | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename Dockerfile.cuda => docker/Dockerfile.cuda_simple (100%) rename Dockerfile => docker/Dockerfile.openblas_simple (100%) diff --git a/Dockerfile.cuda b/docker/Dockerfile.cuda_simple similarity index 100% rename from Dockerfile.cuda rename to docker/Dockerfile.cuda_simple diff --git a/Dockerfile b/docker/Dockerfile.openblas_simple similarity index 100% rename from Dockerfile rename to docker/Dockerfile.openblas_simple From ec44bdad614c68b3b3f904ff04ecc68ea158ff3e Mon Sep 17 00:00:00 2001 From: Gary Mulder Date: Tue, 23 May 2023 20:50:39 +0100 Subject: [PATCH 04/11] Update README.md --- docker/README.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/docker/README.md b/docker/README.md index 3a538af..100bcbd 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,3 +1,9 @@ +# Dockerfiles for building the llama-cpp-python server +- `Dockerfile.openblas_simple` - a simple Dockerfile for non-GPU OpenBLAS +- `Dockerfile.cuda_simple` - a simple Dockerfile for CUDA accelerated CuBLAS +- `hug_model.py` - a Python utility for interactively choosing and downloading the latest `5_1` quantized models from [huggingface.co/TheBloke]( https://huggingface.co/TheBloke) +- `Dockerfile` - a single OpenBLAS and CuBLAS combined Dockerfile that automatically installs a previously downloaded model `model.bin` + # Get model from Hugging Face `python3 ./hug_model.py` @@ -7,7 +13,7 @@ docker $ ls -lh *.bin -rw-rw-r-- 1 user user 4.8G May 23 18:30 .q5_1.bin lrwxrwxrwx 1 user user 24 May 23 18:30 model.bin -> .q5_1.bin ``` -**Note #1:** Make sure you have enough disk space to d/l the model. 
As the model is then copied into the image you will need at least +**Note #1:** Make sure you have enough disk space to download the model. As the model is then copied into the image you will need at least **TWICE** as much disk space as the size of the model: | Model | Quantized size | @@ -21,20 +27,20 @@ lrwxrwxrwx 1 user user 24 May 23 18:30 model.bin -> .q5 # Install Docker Server -**Note #3:** This was tested with Docker running on Linux. If you can get it working on Windows or MacOS, please update this README with a PR! +**Note #3:** This was tested with Docker running on Linux. If you can get it working on Windows or MacOS, please update this `README.md` with a PR! [Install Docker Engine](https://docs.docker.com/engine/install) # Use OpenBLAS -No NVidia GPU, defaults to `python:3-slim-bullseye` Docker base image and OpenBlAS: +Use if you don't have a NVidia GPU. Defaults to `python:3-slim-bullseye` Docker base image and OpenBLAS: ## Build: `docker build --build-arg -t openblas .` ## Run: `docker run --cap-add SYS_RESOURCE -t openblas` # Use CuBLAS -Requires NVidia GPU and Docker NVidia support (see [container-toolkit/install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)) +Requires a NVidia GPU with sufficient VRAM (approximately as much as the size above) and Docker NVidia support (see [container-toolkit/install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)) ## Build: -`docker build --build-arg IMAGE=nvidia/cuda:12.1.1-devel-ubuntu22.04 -t opencuda .` +`docker build --build-arg IMAGE=nvidia/cuda:12.1.1-devel-ubuntu22.04 -t cublas .` ## Run: `docker run --cap-add SYS_RESOURCE -t cublas` From 5bb780d455d4158870c231a7fde1fa16863361f1 Mon Sep 17 00:00:00 2001 From: Maximilian-Winter Date: Wed, 24 May 2023 21:55:44 +0200 Subject: [PATCH 05/11] Implemented logit processors and stop criteria's --- llama_cpp/llama.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 916fe07..cf1e719 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -316,6 +316,7 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, + logits_processors=None ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -328,6 +329,10 @@ class Llama: else last_n_tokens_size ) logits = self.eval_logits[-1] + for processor in logits_processors: + logits = processor(list(self.eval_tokens), logits) + + self.eval_logits[-1] = logits nl_logit = logits[self._token_nl] candidates = self._candidates for i, logit in enumerate(logits): @@ -436,6 +441,8 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, + logits_processors=None + ): """Sample a token from the model. @@ -468,6 +475,8 @@ class Llama: mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, + logits_processors=logits_processors + ) def generate( @@ -484,6 +493,7 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + logits_processors=None ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. 
@@ -541,6 +551,7 @@ class Llama: mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + logits_processors=logits_processors ) tokens_or_none = yield token tokens = [token] @@ -637,6 +648,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, + logits_processors=None, + stopping_criterias=None ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None completion_id: str = f"cmpl-{str(uuid.uuid4())}" @@ -700,6 +713,7 @@ class Llama: frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, + logits_processors=logits_processors ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -707,6 +721,14 @@ class Llama: break completion_tokens.append(token) + for stopping_crit in stopping_criterias: + if stopping_crit(completion_tokens, None): + text = self.detokenize(completion_tokens) + finish_reason = "stop" + break + + if finish_reason == "stop": + break all_text = self.detokenize(completion_tokens) @@ -1006,6 +1028,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, + logits_processors=None, + stopping_criterias=None ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1048,6 +1072,9 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, model=model, + logits_processors=logits_processors, + stopping_criterias=stopping_criterias + ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks From c05fcdf42f991d5c43dea3377dc1529adcd45167 Mon Sep 17 00:00:00 2001 From: Maximilian-Winter Date: Wed, 24 May 2023 22:02:06 +0200 Subject: [PATCH 06/11] Fixed none value of logits processors. --- llama_cpp/llama.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index cf1e719..c6f540c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -320,6 +320,10 @@ class Llama: ): assert self.ctx is not None assert len(self.eval_logits) > 0 + + if logits_processors == None: + logits_processors = [] + n_vocab = self.n_vocab() n_ctx = self.n_ctx() top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k @@ -652,6 +656,10 @@ class Llama: stopping_criterias=None ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None + + if stopping_criterias == None: + stopping_criterias = [] + completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) completion_tokens: List[int] = [] From da463e6c8c3c09c7a32bf25d924974d74f3d2776 Mon Sep 17 00:00:00 2001 From: Maximilian-Winter Date: Thu, 25 May 2023 09:07:16 +0200 Subject: [PATCH 07/11] Added types to logit processor list and stop criteria list --- llama_cpp/llama.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index c6f540c..8176136 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,7 +4,7 @@ import uuid import time import math import multiprocessing -from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple +from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable from collections import deque, OrderedDict from . 
import llama_cpp @@ -316,12 +316,11 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, - logits_processors=None + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None ): assert self.ctx is not None assert len(self.eval_logits) > 0 - - if logits_processors == None: + if logits_processors is None: logits_processors = [] n_vocab = self.n_vocab() @@ -445,7 +444,7 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, - logits_processors=None + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None ): """Sample a token from the model. @@ -497,7 +496,7 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - logits_processors=None + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -652,12 +651,12 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors=None, - stopping_criterias=None + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None, + stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None - if stopping_criterias == None: + if stopping_criterias is None: stopping_criterias = [] completion_id: str = f"cmpl-{str(uuid.uuid4())}" @@ -1036,8 +1035,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors=None, - stopping_criterias=None + logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None, + stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. From c2585b68894102ace0cb0c54dc812e27c36482b9 Mon Sep 17 00:00:00 2001 From: Maximilian-Winter Date: Thu, 25 May 2023 10:54:08 +0200 Subject: [PATCH 08/11] Fixed list elements typing --- llama_cpp/llama.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 8176136..144671b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -316,7 +316,7 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, - logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None + logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -444,7 +444,7 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, - logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None + logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None ): """Sample a token from the model. 
@@ -496,7 +496,7 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None + logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -651,8 +651,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None, - stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None, + logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None, + stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None @@ -1035,8 +1035,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None, - stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None + logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None, + stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. From 0d2cc21202620e92fd152981c6b7ecc0190a6124 Mon Sep 17 00:00:00 2001 From: Gary Mulder Date: Thu, 25 May 2023 11:50:02 +0000 Subject: [PATCH 09/11] Fixed repeated imports --- docker/hug_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/docker/hug_model.py b/docker/hug_model.py index 476f53c..848a1aa 100644 --- a/docker/hug_model.py +++ b/docker/hug_model.py @@ -25,9 +25,6 @@ def check_magic_and_version(filename): return magic, version -import os -import requests - def download_file(url, destination): print(f"Downloading {url} to {destination}...") response = requests.get(url, stream=True) From 1d247e0f350948667553f3c880f8df40f0b5c787 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 25 May 2023 14:04:54 -0400 Subject: [PATCH 10/11] Add StoppingCriteria and LogitsProcessor to generate to match huggingface API --- llama_cpp/llama.py | 74 ++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 144671b..b7a8d79 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,7 +4,17 @@ import uuid import time import math import multiprocessing -from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable +from typing import ( + List, + Optional, + Union, + Generator, + Sequence, + Iterator, + Deque, + Tuple, + Callable, +) from collections import deque, OrderedDict from . 
import llama_cpp @@ -72,6 +82,24 @@ class LlamaState: self.llama_state_size = llama_state_size +LogitsProcessor = Callable[[List[int], List[float]], List[float]] + + +class LogitsProcessorList(List[LogitsProcessor]): + def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +StoppingCriteria = Callable[[List[int], List[float]], bool] + + +class StoppingCriteriaList(List[StoppingCriteria]): + def __call__(self, input_ids: List[int], logits: List[float]) -> bool: + return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) + + class Llama: """High-level Python wrapper for a llama.cpp model.""" @@ -316,12 +344,10 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None + logits_processor: Optional[LogitsProcessorList] = None, ): assert self.ctx is not None assert len(self.eval_logits) > 0 - if logits_processors is None: - logits_processors = [] n_vocab = self.n_vocab() n_ctx = self.n_ctx() @@ -332,10 +358,10 @@ class Llama: else last_n_tokens_size ) logits = self.eval_logits[-1] - for processor in logits_processors: - logits = processor(list(self.eval_tokens), logits) - self.eval_logits[-1] = logits + if logits_processor is not None: + logits = logits_processor(list(self.eval_tokens), logits) + nl_logit = logits[self._token_nl] candidates = self._candidates for i, logit in enumerate(logits): @@ -444,8 +470,7 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None - + logits_processor: Optional[LogitsProcessorList] = None, ): """Sample a token from the model. @@ -478,8 +503,7 @@ class Llama: mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, - logits_processors=logits_processors - + logits_processor=logits_processor, ) def generate( @@ -496,7 +520,8 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. 
@@ -554,8 +579,12 @@ class Llama: mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, - logits_processors=logits_processors + logits_processor=logits_processor, ) + if stopping_criteria is not None and stopping_criteria( + list(self.eval_tokens), self.eval_logits[-1] + ): + return tokens_or_none = yield token tokens = [token] if tokens_or_none is not None: @@ -651,14 +680,9 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None, - stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None - if stopping_criterias is None: - stopping_criterias = [] - completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) completion_tokens: List[int] = [] @@ -720,7 +744,6 @@ class Llama: frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, - logits_processors=logits_processors ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -728,14 +751,6 @@ class Llama: break completion_tokens.append(token) - for stopping_crit in stopping_criterias: - if stopping_crit(completion_tokens, None): - text = self.detokenize(completion_tokens) - finish_reason = "stop" - break - - if finish_reason == "stop": - break all_text = self.detokenize(completion_tokens) @@ -1035,8 +1050,6 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None, - stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1079,9 +1092,6 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, model=model, - logits_processors=logits_processors, - stopping_criterias=stopping_criterias - ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks From ca01f98e09f2f4146d8adb19efbd48460a99068c Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 25 May 2023 14:11:33 -0400 Subject: [PATCH 11/11] Add LlamaTokenizer class --- llama_cpp/llama.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b7a8d79..7dd1acb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1380,6 +1380,11 @@ class Llama: assert self.ctx is not None return llama_cpp.llama_n_vocab(self.ctx) + def tokenizer(self) -> "LlamaTokenizer": + """Return the tokenizer for this model.""" + assert self.ctx is not None + return LlamaTokenizer(self) + @staticmethod def token_eos() -> int: """Return the end-of-sequence token.""" @@ -1410,3 +1415,18 @@ class Llama: else: break return longest_prefix + + +class LlamaTokenizer: + def __init__(self, llama: Llama): + self.llama = llama + + def encode(self, text: str) -> List[int]: + return self.llama.tokenize(text.encode("utf-8", errors="ignore")) + + def decode(self, tokens: List[int]) -> str: + return self.llama.detokenize(tokens).decode("utf-8", errors="ignore") + + @classmethod + def from_ggml_file(cls, path: str) -> "LlamaTokenizer": + return cls(Llama(model_path=path, vocab_only=True))
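
---

PATCH 10 replaces the earlier ad-hoc `logits_processors`/`stopping_criterias` keyword arguments with Hugging Face-style `LogitsProcessorList` and `StoppingCriteriaList` hooks on `Llama.generate()` (and `logits_processor` on `Llama.sample()`), and PATCH 11 adds a `LlamaTokenizer` wrapper. The following standalone sketch, which is not part of the patch series, shows how these pieces might be driven together; the model path, prompt, banned token, sampling values, and 32-token cut-off are illustrative assumptions only.

```
from typing import List

from llama_cpp.llama import Llama, LogitsProcessorList, StoppingCriteriaList

# Hypothetical local model path, e.g. the model.bin produced by docker/hug_model.py
llm = Llama(model_path="./model.bin")
tokenizer = llm.tokenizer()  # LlamaTokenizer from PATCH 11

prompt_tokens = tokenizer.encode("Q: Name the planets in the solar system. A:")

# Token id to suppress; taking the last id of the encoding is an assumption
# that works for short single-token strings such as "\n".
NEWLINE_TOKEN = tokenizer.encode("\n")[-1]

def ban_newline(input_ids: List[int], scores: List[float]) -> List[float]:
    # A logits processor receives the tokens seen so far plus the raw
    # next-token logits and returns (possibly modified) logits.
    scores = list(scores)
    scores[NEWLINE_TOKEN] = -float("inf")
    return scores

def stop_after_32(input_ids: List[int], logits: List[float]) -> bool:
    # A stopping criterion returns True when generation should halt.
    return len(input_ids) >= len(prompt_tokens) + 32

generated: List[int] = []
for token in llm.generate(
    prompt_tokens,
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
    logits_processor=LogitsProcessorList([ban_newline]),
    stopping_criteria=StoppingCriteriaList([stop_after_32]),
):
    if token == llm.token_eos():
        break
    generated.append(token)

print(tokenizer.decode(generated))
```

Note that after PATCH 10 the `create_completion()` path no longer accepts these hooks directly; they are wired into `generate()` and `sample()` instead.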