2024-01-17 09:16:13 -05:00
from __future__ import annotations
2023-03-24 15:47:17 -04:00
import os
2023-04-04 13:09:24 -04:00
import sys
2023-03-23 05:33:06 -04:00
import uuid
import time
2024-02-21 16:25:10 -05:00
import json
2024-02-23 11:24:53 -05:00
import ctypes
2024-05-29 02:02:22 -04:00
import typing
2024-02-21 16:25:10 -05:00
import fnmatch
2024-06-04 16:15:41 +02:00
import warnings
feat: Add `.close()` method to `Llama` class to explicitly free model from memory (#1513)
* feat: add explicit methods to free model
This commit introduces a `close` method to both `Llama` and `_LlamaModel`,
allowing users to explicitly free the model from RAM/VRAM.
The previous implementation relied on the destructor of `_LlamaModel` to free
the model. However, in Python, the timing of destructor calls is unclear—for
instance, the `del` statement does not guarantee immediate invocation of the
destructor.
This commit provides an explicit method to release the model, which works
immediately and allows the user to load another model without memory issues.
Additionally, this commit implements a context manager in the `Llama` class,
enabling the automatic closure of the `Llama` object when used with the `with`
statement.
* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch
This commit enables automatic resource management by
implementing the `ContextManager` protocol in `_LlamaModel`,
`_LlamaContext`, and `_LlamaBatch`. This ensures that
resources are properly managed and released within a `with`
statement, enhancing robustness and safety in resource handling.
* feat: add ExitStack for Llama's internal class closure
This update implements ExitStack to manage and close internal
classes in Llama, enhancing efficient and safe resource
management.
* Use contextlib ExitStack and closing
* Explicitly free model when closing resources on server
---------
Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-06-13 02:16:14 -06:00
import contextlib
2023-03-23 05:33:06 -04:00
import multiprocessing
feat: Add `.close()` method to `Llama` class to explicitly free model from memory (#1513)
* feat: add explicit methods to free model
This commit introduces a `close` method to both `Llama` and `_LlamaModel`,
allowing users to explicitly free the model from RAM/VRAM.
The previous implementation relied on the destructor of `_LlamaModel` to free
the model. However, in Python, the timing of destructor calls is unclear—for
instance, the `del` statement does not guarantee immediate invocation of the
destructor.
This commit provides an explicit method to release the model, which works
immediately and allows the user to load another model without memory issues.
Additionally, this commit implements a context manager in the `Llama` class,
enabling the automatic closure of the `Llama` object when used with the `with`
statement.
* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch
This commit enables automatic resource management by
implementing the `ContextManager` protocol in `_LlamaModel`,
`_LlamaContext`, and `_LlamaBatch`. This ensures that
resources are properly managed and released within a `with`
statement, enhancing robustness and safety in resource handling.
* feat: add ExitStack for Llama's internal class closure
This update implements ExitStack to manage and close internal
classes in Llama, enhancing efficient and safe resource
management.
* Use contextlib ExitStack and closing
* Explicitly free model when closing resources on server
---------
Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-06-13 02:16:14 -06:00
from types import TracebackType
2024-02-23 11:24:53 -05:00
2023-05-25 14:04:54 -04:00
from typing import (
List ,
Optional ,
Union ,
Generator ,
Sequence ,
Iterator ,
Deque ,
Callable ,
2024-04-17 09:06:50 -05:00
Dict ,
feat: Add `.close()` method to `Llama` class to explicitly free model from memory (#1513)
* feat: add explicit methods to free model
This commit introduces a `close` method to both `Llama` and `_LlamaModel`,
allowing users to explicitly free the model from RAM/VRAM.
The previous implementation relied on the destructor of `_LlamaModel` to free
the model. However, in Python, the timing of destructor calls is unclear—for
instance, the `del` statement does not guarantee immediate invocation of the
destructor.
This commit provides an explicit method to release the model, which works
immediately and allows the user to load another model without memory issues.
Additionally, this commit implements a context manager in the `Llama` class,
enabling the automatic closure of the `Llama` object when used with the `with`
statement.
* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch
This commit enables automatic resource management by
implementing the `ContextManager` protocol in `_LlamaModel`,
`_LlamaContext`, and `_LlamaBatch`. This ensures that
resources are properly managed and released within a `with`
statement, enhancing robustness and safety in resource handling.
* feat: add ExitStack for Llama's internal class closure
This update implements ExitStack to manage and close internal
classes in Llama, enhancing efficient and safe resource
management.
* Use contextlib ExitStack and closing
* Explicitly free model when closing resources on server
---------
Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-06-13 02:16:14 -06:00
Type ,
2023-05-25 14:04:54 -04:00
)
2024-01-17 09:09:12 -05:00
from collections import deque
2024-02-21 16:25:10 -05:00
from pathlib import Path
2023-03-23 05:33:06 -04:00
2024-02-08 09:07:03 +08:00
from llama_cpp . llama_types import List
2023-04-01 13:01:27 -04:00
from . llama_types import *
2023-08-07 02:21:37 +09:00
from . llama_grammar import LlamaGrammar
2024-01-17 09:09:12 -05:00
from . llama_cache import (
BaseLlamaCache ,
LlamaCache , # type: ignore
LlamaDiskCache , # type: ignore
LlamaRAMCache , # type: ignore
)
2024-02-21 16:25:10 -05:00
from . llama_tokenizer import BaseLlamaTokenizer , LlamaTokenizer
2023-11-08 04:48:51 +01:00
import llama_cpp . llama_cpp as llama_cpp
2023-11-03 02:12:14 -04:00
import llama_cpp . llama_chat_format as llama_chat_format
2023-03-23 05:33:06 -04:00
2024-01-31 14:08:14 -05:00
from llama_cpp . llama_speculative import LlamaDraftModel
2023-05-26 16:12:45 -04:00
import numpy as np
import numpy . typing as npt
2024-01-17 09:14:00 -05:00
from . _internals import (
_LlamaModel , # type: ignore
_LlamaContext , # type: ignore
_LlamaBatch , # type: ignore
_LlamaTokenDataArray , # type: ignore
2024-01-31 14:08:14 -05:00
_LlamaSamplingParams , # type: ignore
_LlamaSamplingContext , # type: ignore
2024-04-25 20:32:44 -05:00
_normalize_embedding , # type: ignore
2024-01-17 09:14:00 -05:00
)
2024-02-05 21:52:12 -05:00
from . _logger import set_verbose
2024-02-21 16:25:10 -05:00
from . _utils import suppress_stdout_stderr
2023-07-18 19:27:41 -04:00
2023-09-28 22:42:03 -04:00
2023-03-23 05:33:06 -04:00
class Llama :
2023-03-24 18:57:59 -04:00
""" High-level Python wrapper for a llama.cpp model. """
2023-09-13 23:00:43 -04:00
__backend_initialized = False
2023-03-23 05:33:06 -04:00
def __init__(
    self,
    model_path: str,
    *,
    # Model Params
    n_gpu_layers: int = 0,
    split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
    main_gpu: int = 0,
    tensor_split: Optional[List[float]] = None,
    rpc_servers: Optional[str] = None,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
    # Context Params
    seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
    n_ctx: int = 512,
    n_batch: int = 512,
    n_threads: Optional[int] = None,
    n_threads_batch: Optional[int] = None,
    rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
    pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
    rope_freq_base: float = 0.0,
    rope_freq_scale: float = 0.0,
    yarn_ext_factor: float = -1.0,
    yarn_attn_factor: float = 1.0,
    yarn_beta_fast: float = 32.0,
    yarn_beta_slow: float = 1.0,
    yarn_orig_ctx: int = 0,
    logits_all: bool = False,
    embedding: bool = False,
    offload_kqv: bool = True,
    flash_attn: bool = False,
    # Sampling Params
    last_n_tokens_size: int = 64,
    # LoRA Params
    lora_base: Optional[str] = None,
    lora_scale: float = 1.0,
    lora_path: Optional[str] = None,
    # Backend Params
    numa: Union[bool, int] = False,
    # Chat Format Params
    chat_format: Optional[str] = None,
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
    # Speculative Decoding
    draft_model: Optional[LlamaDraftModel] = None,
    # Tokenizer Override
    tokenizer: Optional[BaseLlamaTokenizer] = None,
    # KV cache quantization
    type_k: Optional[int] = None,
    type_v: Optional[int] = None,
    # Misc
    spm_infill: bool = False,
    verbose: bool = True,
    # Extra Params
    **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        rpc_servers: Comma separated list of RPC servers to use for offloading
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Forced to False whenever `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model. Values may be bool, int, float or str.
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size (clamped to `n_ctx`)
        n_threads: Number of threads to use for generation. Defaults to half the CPU count (at least 1).
        n_threads_batch: Number of threads to use for batch processing. Defaults to the CPU count.
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        pooling_type: Pooling type, from `enum llama_pooling_type`.
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. Forced to True when `draft_model` is given.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        flash_attn: Use flash attention.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale value handed to the model together with `lora_path` when the LoRA is applied.
        lora_path: Path to a LoRA file to apply to the model.
        numa: numa policy. A bool selects between the "distribute" and "disabled" strategies; an int is used as the strategy value directly.
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        draft_model: Optional draft model to use for speculative decoding.
        tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
        verbose: Print verbose output to stderr.
        type_k: KV cache data type for K (default: f16)
        type_v: KV cache data type for V (default: f16)
        spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
        kwargs: Extra keyword arguments. They are accepted but not used by this constructor.

    Raises:
        ValueError: If the model path does not exist, if `tensor_split` has more
            entries than `llama_cpp.LLAMA_MAX_DEVICES`, or if a `kv_overrides`
            value has an unsupported type or is too long.
        RuntimeError: If applying the LoRA from `lora_path` fails.

    Returns:
        A Llama instance.
    """
    self.verbose = verbose

    set_verbose(verbose)

    # The backend is only initialised by the first instance; the flag is
    # stored on the class so later instances skip this step.
    if not Llama.__backend_initialized:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_backend_init()
        Llama.__backend_initialized = True

    if isinstance(numa, bool):
        self.numa = (
            llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
            if numa
            else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
        )
    else:
        self.numa = numa

    if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_numa_init(self.numa)

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    if rpc_servers is not None:
        self.model_params.rpc_servers = rpc_servers.encode("utf-8")
        self._rpc_servers = rpc_servers
    else:
        self._rpc_servers = None
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    # kv_overrides is the original python dict
    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
        kvo_array_len = len(kv_overrides) + 1  # for sentinel element
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * kvo_array_len
        )()

        # Size of the string slot inside the override value.
        # NOTE(review): assumes `val_str` is a 128-byte char buffer that the
        # C side reads as a NUL-terminated string -- confirm against llama.h.
        val_str_size = 128  # TODO: Make this a constant

        for i, (k, v) in enumerate(kv_overrides.items()):
            entry = self._kv_overrides_array[i]
            entry.key = k.encode("utf-8")
            # bool must be tested before int because bool is a subclass of int.
            if isinstance(v, bool):
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                entry.value.val_bool = v
            elif isinstance(v, int):
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                entry.value.val_i64 = v
            elif isinstance(v, float):
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                entry.value.val_f64 = v
            elif isinstance(v, str):  # type: ignore
                v_bytes = v.encode("utf-8")
                # Keep at least one byte free so the padded buffer always
                # ends in a NUL terminator.
                if len(v_bytes) >= val_str_size:
                    raise ValueError(f"Value for {k} is too long: {v}")
                v_bytes = v_bytes.ljust(val_str_size, b"\0")
                entry.tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                # Copy the padded bytes straight into the val_str slot of
                # the value, located via the field offset.
                address = typing.cast(
                    int,
                    ctypes.addressof(entry.value)
                    + llama_cpp.llama_model_kv_override_value.val_str.offset,
                )
                buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
                ctypes.memmove(buffer_start, v_bytes, val_str_size)
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")

        self._kv_overrides_array[-1].key = (
            b"\0"  # ensure sentinel element is zeroed
        )
        self.model_params.kv_overrides = self._kv_overrides_array

    self.n_batch = min(n_ctx, n_batch)  # ???
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.seed = seed
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    self.context_params.pooling_type = pooling_type
    # The RoPE / YaRN values are forwarded unchanged.
    self.context_params.rope_freq_base = rope_freq_base
    self.context_params.rope_freq_scale = rope_freq_scale
    self.context_params.yarn_ext_factor = yarn_ext_factor
    self.context_params.yarn_attn_factor = yarn_attn_factor
    self.context_params.yarn_beta_fast = yarn_beta_fast
    self.context_params.yarn_beta_slow = yarn_beta_slow
    self.context_params.yarn_orig_ctx = yarn_orig_ctx
    self.context_params.logits_all = (
        logits_all if draft_model is None else True
    )  # Must be set to True for speculative decoding
    self.context_params.embeddings = embedding  # TODO: Rename to embeddings
    self.context_params.offload_kqv = offload_kqv
    self.context_params.flash_attn = flash_attn

    # KV cache quantization
    if type_k is not None:
        self.context_params.type_k = type_k
    if type_v is not None:
        self.context_params.type_v = type_v

    # Sampling Params
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    self.spm_infill = spm_infill

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    # Every native resource is registered on this stack so that closing the
    # stack closes them in reverse order of creation.
    self._stack = contextlib.ExitStack()

    try:
        self._model = self._stack.enter_context(
            contextlib.closing(
                _LlamaModel(
                    path_model=self.model_path,
                    params=self.model_params,
                    verbose=self.verbose,
                )
            )
        )

        # Override tokenizer
        self.tokenizer_ = tokenizer or LlamaTokenizer(self)

        # Set the default value for the context and correct the batch
        if n_ctx == 0:
            n_ctx = self._model.n_ctx_train()
            self.n_batch = min(n_ctx, n_batch)
            self.context_params.n_ctx = n_ctx
            self.context_params.n_batch = self.n_batch

        self._ctx = self._stack.enter_context(
            contextlib.closing(
                _LlamaContext(
                    model=self._model,
                    params=self.context_params,
                    verbose=self.verbose,
                )
            )
        )

        self._batch = self._stack.enter_context(
            contextlib.closing(
                _LlamaBatch(
                    n_tokens=self.n_batch,
                    embd=0,
                    n_seq_max=self.context_params.n_ctx,
                    verbose=self.verbose,
                )
            )
        )

        if self.lora_path:
            # A truthy return value from apply_lora_from_file signals failure.
            if self._model.apply_lora_from_file(
                self.lora_path,
                self.lora_scale,
                self.lora_base,
                self.n_threads,
            ):
                raise RuntimeError(
                    f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
                )
    except BaseException:
        # Construction failed part-way: close whatever was already entered
        # on the stack instead of leaving it for garbage collection.
        self._stack.close()
        raise

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler
    self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}

    self.draft_model = draft_model

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab)

    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx, self._n_vocab), dtype=np.single
    )

    self._mirostat_mu = ctypes.c_float(
        2.0 * 5.0
    )  # TODO: Move this to sampling context

    # Metadata is best-effort: a failure here must not prevent the model
    # from being used, so fall back to an empty dict.
    try:
        self.metadata = self._model.metadata()
    except Exception as e:
        self.metadata = {}
        if self.verbose:
            print(f"Failed to load metadata: {e}", file=sys.stderr)

    if self.verbose:
        print(f"Model metadata: {self.metadata}", file=sys.stderr)

    eos_token_id = self.token_eos()
    bos_token_id = self.token_bos()

    eos_token = self._model.token_get_text(eos_token_id) if eos_token_id != -1 else ""
    bos_token = self._model.token_get_text(bos_token_id) if bos_token_id != -1 else ""

    # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
    # Strip the leading "tokenizer." so keys read "chat_template.<name>".
    key_prefix_len = len("tokenizer.")
    template_choices = {
        name[key_prefix_len:]: template
        for name, template in self.metadata.items()
        if name.startswith("tokenizer.chat_template.")
    }

    if "tokenizer.chat_template" in self.metadata:
        template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]

    if self.verbose and template_choices:
        print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)

    for name, template in template_choices.items():
        self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
            template=template,
            eos_token=eos_token,
            bos_token=bos_token,
            stop_token_ids=[eos_token_id],
        ).to_chat_handler()

    # Only pick a format automatically when the caller supplied neither a
    # chat_format nor a chat_handler.
    if (
        self.chat_format is None
        and self.chat_handler is None
        and "chat_template.default" in template_choices
    ):
        chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
            self.metadata
        )

        if chat_format is not None:
            self.chat_format = chat_format
            if self.verbose:
                print(f"Guessed chat format: {chat_format}", file=sys.stderr)
        else:
            if self.verbose:
                print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_format = "chat_template.default"

    if self.chat_format is None and self.chat_handler is None:
        self.chat_format = "llama-2"
        if self.verbose:
            print(f"Using fallback chat format: {self.chat_format}", file=sys.stderr)
2024-01-29 14:22:23 -05:00
2023-11-06 09:16:36 -05:00
@property
def ctx(self) -> llama_cpp.llama_context_p:
    """Return the raw context pointer stored on the internal ``_ctx`` wrapper.

    Asserts that the pointer is not ``None`` before handing it out.
    """
    handle = self._ctx.ctx
    assert handle is not None
    return handle
@property
def model(self) -> llama_cpp.llama_model_p:
    """Return the raw model pointer stored on the internal ``_model`` wrapper.

    Asserts that the pointer is not ``None`` before handing it out.
    """
    handle = self._model.model
    assert handle is not None
    return handle
2023-06-29 00:40:47 -04:00
@property
def _input_ids(self) -> npt.NDArray[np.intc]:
    """Slice of ``input_ids`` covering only the first ``n_tokens`` entries."""
    used = self.n_tokens
    return self.input_ids[:used]
@property
def _scores(self) -> npt.NDArray[np.single]:
    """Slice of ``scores`` covering only the first ``n_tokens`` rows."""
    used = self.n_tokens
    return self.scores[:used, :]
@property
def eval_tokens(self) -> Deque[int]:
    """Deque holding the first ``n_tokens`` token ids, bounded by ``_n_ctx``."""
    token_list = self.input_ids[: self.n_tokens].tolist()
    return deque(token_list, maxlen=self._n_ctx)
@property
def eval_logits(self) -> Deque[List[float]]:
    """Deque of the score rows for the first ``n_tokens`` positions.

    The deque keeps up to ``_n_ctx`` rows when ``context_params.logits_all``
    is set, and only a single row otherwise.
    """
    rows = self.scores[: self.n_tokens, :].tolist()
    if self.context_params.logits_all:
        limit = self._n_ctx
    else:
        limit = 1
    return deque(rows, maxlen=limit)
2023-05-26 16:12:45 -04:00
2023-11-06 09:16:36 -05:00
def tokenize(
    self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
    """Tokenize a string.

    The call is delegated to ``self.tokenizer_``.

    Args:
        text: The utf-8 encoded string to tokenize.
        add_bos: Passed through unchanged to the tokenizer.
        special: Passed through unchanged to the tokenizer.

    Raises:
        RuntimeError: If the tokenization failed.

    Returns:
        A list of tokens.
    """
    tokenizer = self.tokenizer_
    return tokenizer.tokenize(text, add_bos, special)
2023-03-28 01:45:37 -04:00
2024-02-21 16:25:10 -05:00
def detokenize(
    self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
    """Detokenize a list of tokens.

    The call is delegated to ``self.tokenizer_``.

    Args:
        tokens: The list of tokens to detokenize.
        prev_tokens: The list of previous tokens. Offset mapping will be
            performed if provided.

    Returns:
        The detokenized string.
    """
    tokenizer = self.tokenizer_
    return tokenizer.detokenize(tokens, prev_tokens=prev_tokens)
2023-03-28 01:45:37 -04:00
2023-06-08 13:19:23 -04:00
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Set the cache.

    Stores the given object on ``self.cache``, replacing whatever was there;
    ``None`` is stored as-is.

    Args:
        cache: The cache to set.
    """
    setattr(self, "cache", cache)
2023-04-15 12:03:09 -04:00
2023-11-08 11:09:41 -05:00
def set_seed(self, seed: int):
    """Set the random seed.

    Forwards the seed to ``llama_cpp.llama_set_rng_seed`` together with the
    raw context pointer, which must not be ``None``.

    Args:
        seed: The random seed.
    """
    raw_ctx = self._ctx.ctx
    assert raw_ctx is not None
    llama_cpp.llama_set_rng_seed(raw_ctx, seed)
2023-04-02 00:02:47 -04:00
def reset(self):
    """Reset the model state.

    Only the ``n_tokens`` counter is set back to zero; the token and score
    buffers themselves are left untouched.
    """
    setattr(self, "n_tokens", 0)
2023-04-02 00:02:47 -04:00
2023-05-19 11:59:33 -04:00
def eval(self, tokens: Sequence[int]):
    """Evaluate a list of tokens.

    The tokens are decoded in chunks of at most ``self.n_batch``. After each
    chunk the token ids are stored in ``self.input_ids``, the logits are
    copied into ``self.scores`` and ``self.n_tokens`` is advanced.

    Args:
        tokens: The list of tokens to evaluate.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None
    # Remove KV-cache entries from position n_tokens onward before decoding.
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
    for start in range(0, len(tokens), self.n_batch):
        chunk = tokens[start : min(len(tokens), start + self.n_batch)]
        n_past = self.n_tokens
        chunk_len = len(chunk)
        self._batch.set_batch(
            batch=chunk, n_past=n_past, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)
        # Save tokens
        self.input_ids[n_past : n_past + chunk_len] = chunk
        # Save logits: every row of the chunk when logits_all is set,
        # otherwise only the row belonging to the chunk's last token.
        keep_all = self.context_params.logits_all
        row_count = chunk_len if keep_all else 1
        flat_logits = np.ctypeslib.as_array(
            self._ctx.get_logits(), shape=(row_count * self._n_vocab,)
        )
        if keep_all:
            destination = self.scores[n_past : n_past + chunk_len, :]
        else:
            destination = self.scores[n_past + chunk_len - 1, :]
        destination.reshape(-1)[:] = flat_logits
        # Update n_tokens
        self.n_tokens += chunk_len
2023-05-01 14:47:55 -04:00
2023-11-06 09:16:36 -05:00
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    idx: Optional[int] = None,
):
    """Sample a token from the model.

    Args:
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        logits_processor: Optional callable applied to the selected logits
            row; its result is written back into that row in place.
        grammar: Optional grammar handed to the sampling context.
        idx: Row of the evaluated scores to sample from; the last evaluated
            row is used when omitted.

    The remaining keyword arguments are copied verbatim into the sampling
    parameters.

    Returns:
        The sampled token.
    """
    assert self._ctx is not None
    assert self.n_tokens > 0

    row = -1 if idx is None else idx
    logits: npt.NDArray[np.single] = self._scores[row, :]

    if logits_processor is not None:
        # The processor sees the full history, or only the ids up to and
        # including ``idx`` when an explicit row was requested.
        history = self._input_ids if idx is None else self._input_ids[: idx + 1]
        logits[:] = logits_processor(history, logits)

    params = _LlamaSamplingParams(
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        tfs_z=tfs_z,
        typical_p=typical_p,
        temp=temp,
        penalty_last_n=self.last_n_tokens_size,
        penalty_repeat=repeat_penalty,
        penalty_freq=frequency_penalty,
        penalty_present=presence_penalty,
        mirostat=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        penalize_nl=penalize_nl,
    )
    sampler = _LlamaSamplingContext(params=params, grammar=grammar)
    sampler.prev = list(self.eval_tokens)
    token_id = sampler.sample(ctx_main=self._ctx, logits_array=logits)
    sampler.accept(
        ctx_main=self._ctx,
        id=token_id,
        apply_grammar=grammar is not None,
    )
    return token_id
2023-04-02 00:02:47 -04:00
2023-04-01 13:01:27 -04:00
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Create a generator of tokens from a prompt.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state. When the already evaluated
            ids share a prefix with ``tokens`` the reset is skipped and only
            the unmatched remainder is evaluated.
        stopping_criteria: Optional callable given the evaluated ids and the
            last score row; a truthy result ends the generator.
        grammar: Optional grammar; it is reset before sampling starts.

    All other sampling arguments are forwarded unchanged to ``sample``.
    A value sent into the generator is appended after the token just
    yielded for the next evaluation.

    Yields:
        The generated tokens.
    """
    # Reset mirostat sampling
    self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)

    # Check for kv cache prefix match
    if reset and self.n_tokens > 0:
        matched = 0
        for cached_tok, prompt_tok in zip(self._input_ids, tokens[:-1]):
            if cached_tok != prompt_tok:
                break
            matched += 1
        if matched > 0:
            if self.verbose:
                print("Llama.generate: prefix-match hit", file=sys.stderr)
            reset = False
            tokens = tokens[matched:]
            self.n_tokens = matched

    # Reset the model state
    if reset:
        self.reset()

    # Reset the grammar
    if grammar is not None:
        grammar.reset()

    sample_idx = self.n_tokens + len(tokens) - 1
    tokens = list(tokens)

    # None of these arguments change between iterations, so collect them once.
    sampler_kwargs = dict(
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        temp=temp,
        repeat_penalty=repeat_penalty,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        logits_processor=logits_processor,
        grammar=grammar,
        penalize_nl=penalize_nl,
    )

    # Eval and sample
    while True:
        self.eval(tokens)
        while sample_idx < self.n_tokens:
            token = self.sample(idx=sample_idx, **sampler_kwargs)

            sample_idx += 1
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids, self._scores[-1, :]
            ):
                return
            sent = yield token
            tokens.clear()
            tokens.append(token)
            if sent is not None:
                tokens.extend(sent)

            # The sampled token disagrees with an id that was already
            # evaluated at this position: roll back to here and re-evaluate.
            if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                self.n_tokens = sample_idx
                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
                break

        if self.draft_model is not None:
            pending_end = self.n_tokens + len(tokens)
            self.input_ids[self.n_tokens : pending_end] = tokens
            draft_tokens = self.draft_model(self.input_ids[:pending_end])
            # Never queue more draft tokens than fit in the context window.
            room = self._n_ctx - self.n_tokens - len(tokens)
            tokens.extend(draft_tokens.astype(int)[:room])
2023-04-01 13:01:27 -04:00
2023-05-20 01:23:32 +02:00
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
    """Embed a string.

    Args:
        input: The utf-8 encoded string to embed. A single string is wrapped
            in a one-element list before being embedded.
        model: Name reported in the response; ``self.model_path`` is used
            when omitted.

    Returns:
        An embedding object.
    """
    assert self._model.model is not None
    model_name: str = self.model_path if model is None else model

    texts = input if isinstance(input, list) else [input]

    # get numeric embeddings
    embeds, total_tokens = self.embed(texts, return_count=True)  # type: ignore

    # convert to CreateEmbeddingResponse
    data = []
    for position, vector in enumerate(embeds):
        data.append(
            {
                "object": "embedding",
                "embedding": vector,
                "index": position,
            }
        )

    usage = {
        "prompt_tokens": total_tokens,
        "total_tokens": total_tokens,
    }
    return {
        "object": "list",
        "data": data,
        "model": model_name,
        "usage": usage,
    }
def embed(
    self,
    input: Union[str, List[str]],
    normalize: bool = False,
    truncate: bool = True,
    return_count: bool = False,
):
    """Embed a string or a list of strings.

    Inputs are tokenized, packed into batches of at most ``self.n_batch``
    tokens, decoded, and their embeddings collected. When the pooling type is
    ``LLAMA_POOLING_TYPE_NONE`` each input yields one vector per token;
    otherwise each input yields a single vector.

    Args:
        input: The utf-8 encoded string (or list of strings) to embed.
        normalize: Pass every embedding vector through
            ``_normalize_embedding`` before returning it.
        truncate: Cut each input down to at most ``n_batch`` tokens.
        return_count: Also return the total number of tokens processed.

    Raises:
        RuntimeError: If the context was created without embeddings enabled.
        ValueError: If an input has more than ``n_batch`` tokens and
            ``truncate`` is false.

    Returns:
        A list of embeddings (a single embedding when ``input`` is a string),
        or a ``(embeddings, total_tokens)`` tuple when ``return_count`` is set.
    """
    assert self._ctx.ctx is not None
    n_embd = self.n_embd()
    n_batch = self.n_batch

    # get pooling information
    pooling_type = self.pooling_type()
    logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE

    if not self.context_params.embeddings:
        raise RuntimeError(
            "Llama model must be created with embedding=True to call this method"
        )

    if self.verbose:
        llama_cpp.llama_reset_timings(self._ctx.ctx)

    inputs = [input] if isinstance(input, str) else input

    # reset batch
    self._batch.reset()

    # decode and fetch embeddings
    data: Union[List[List[float]], List[List[List[float]]]] = []

    def decode_batch(seq_sizes: List[int]):
        """Decode the queued batch and append one result per sequence to ``data``."""
        assert self._ctx.ctx is not None
        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
        self._ctx.decode(self._batch)
        self._batch.reset()

        # store embeddings
        if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
            # The buffer is flat: one n_embd-wide vector per token, in batch
            # order. ``offset`` is therefore counted in floats and must move
            # by size * n_embd per sequence (advancing by ``size`` alone
            # misaligned every sequence after the first).
            ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
            offset = 0
            for size in seq_sizes:
                token_embeddings: List[List[float]] = [
                    ptr[offset + j * n_embd : offset + (j + 1) * n_embd]
                    for j in range(size)
                ]
                if normalize:
                    token_embeddings = [
                        _normalize_embedding(e) for e in token_embeddings
                    ]
                data.append(token_embeddings)
                offset += size * n_embd
        else:
            for seq_id in range(len(seq_sizes)):
                ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, seq_id)
                pooled: List[float] = ptr[:n_embd]
                if normalize:
                    pooled = _normalize_embedding(pooled)
                data.append(pooled)

    # init state
    total_tokens = 0
    s_batch = []  # token count of each sequence queued in the current batch
    t_batch = 0  # tokens queued in the current batch
    p_batch = 0  # sequence id for the next queued sequence

    # accumulate batches and encode
    for text in inputs:
        tokens = self.tokenize(text.encode("utf-8"))
        if truncate:
            tokens = tokens[:n_batch]

        n_tokens = len(tokens)
        total_tokens += n_tokens

        # check for overrun
        if n_tokens > n_batch:
            raise ValueError(
                f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
            )

        # time to eval batch
        if t_batch + n_tokens > n_batch:
            decode_batch(s_batch)
            s_batch = []
            t_batch = 0
            p_batch = 0

        # add to batch
        self._batch.add_sequence(tokens, p_batch, logits_all)

        # update batch stats
        s_batch.append(n_tokens)
        t_batch += n_tokens
        p_batch += 1

    # handle last batch (nothing is queued when the input list was empty)
    if s_batch:
        decode_batch(s_batch)

    if self.verbose:
        llama_cpp.llama_print_timings(self._ctx.ctx)

    output = data[0] if isinstance(input, str) else data

    llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
    self.reset()

    if return_count:
        return output, total_tokens
    return output
2023-04-03 18:46:19 -04:00
2023-04-01 13:01:27 -04:00
def _create_completion (
2023-03-23 05:33:06 -04:00
self ,
2023-11-08 04:48:51 +01:00
prompt : Union [ str , List [ int ] ] ,
2023-03-23 05:33:06 -04:00
suffix : Optional [ str ] = None ,
2023-11-10 02:49:27 -05:00
max_tokens : Optional [ int ] = 16 ,
2023-03-23 05:33:06 -04:00
temperature : float = 0.8 ,
top_p : float = 0.95 ,
2023-11-21 06:21:33 +02:00
min_p : float = 0.05 ,
typical_p : float = 1.0 ,
2023-03-23 15:51:05 -04:00
logprobs : Optional [ int ] = None ,
2023-03-23 05:33:06 -04:00
echo : bool = False ,
2023-06-08 13:19:23 -04:00
stop : Optional [ Union [ str , List [ str ] ] ] = [ ] ,
frequency_penalty : float = 0.0 ,
presence_penalty : float = 0.0 ,
2023-03-23 05:33:06 -04:00
repeat_penalty : float = 1.1 ,
top_k : int = 40 ,
2023-03-28 04:03:57 -04:00
stream : bool = False ,
2023-11-07 23:37:28 -05:00
seed : Optional [ int ] = None ,
2023-06-08 13:19:23 -04:00
tfs_z : float = 1.0 ,
mirostat_mode : int = 0 ,
mirostat_tau : float = 5.0 ,
mirostat_eta : float = 0.1 ,
model : Optional [ str ] = None ,
stopping_criteria : Optional [ StoppingCriteriaList ] = None ,
logits_processor : Optional [ LogitsProcessorList ] = None ,
2023-08-08 15:08:54 -04:00
grammar : Optional [ LlamaGrammar ] = None ,
2023-11-21 04:01:36 -05:00
logit_bias : Optional [ Dict [ str , float ] ] = None ,
2023-11-08 04:48:51 +01:00
) - > Union [
Iterator [ CreateCompletionResponse ] , Iterator [ CreateCompletionStreamResponse ]
] :
2023-11-06 09:16:36 -05:00
assert self . _ctx is not None
2023-11-01 23:52:50 +01:00
assert suffix is None or suffix . __class__ is str
2023-05-24 22:02:06 +02:00
2023-04-15 11:39:21 -04:00
completion_id : str = f " cmpl- { str ( uuid . uuid4 ( ) ) } "
created : int = int ( time . time ( ) )
2024-06-13 09:45:24 +02:00
bos_token_id : int = self . token_bos ( )
cls_token_id : int = self . _model . token_cls ( )
sep_token_id : int = self . _model . token_sep ( )
2024-05-14 15:44:09 +02:00
prefix_token_id : int = self . _model . token_prefix ( )
middle_token_id : int = self . _model . token_middle ( )
suffix_token_id : int = self . _model . token_suffix ( )
2024-06-13 09:45:24 +02:00
add_space_prefix : bool = self . metadata . get ( " tokenizer.ggml.add_space_prefix " , " true " ) == " true "
bos_tokens : List [ int ] = [ cls_token_id if cls_token_id != - 1 else bos_token_id ]
eos_tokens : List [ int ] = [ sep_token_id if sep_token_id != - 1 else self . token_eos ( ) ]
if ( isinstance ( prompt , list ) and suffix is None ) or self . _model . add_bos_token ( ) == 0 or bos_tokens [ : 1 ] == [ - 1 ] :
bos_tokens = [ ]
if ( isinstance ( prompt , list ) and suffix is None ) or ( self . _model . add_eos_token ( ) != 1 and sep_token_id == - 1 ) :
eos_tokens = [ ]
suffix_space_prefix : int = 0
# Tokenizer hack to remove leading space
if add_space_prefix and suffix_token_id > = 0 and suffix :
suffix = " ☺ " + suffix
suffix_space_prefix = 2
2023-11-20 22:50:59 -05:00
# If prompt is empty, initialize completion with BOS token to avoid
# detokenization including a space at the beginning of the completion
2024-06-13 09:45:24 +02:00
completion_tokens : List [ int ] = [ ] if len ( prompt ) > 0 else [ bos_token_id ]
2023-04-01 13:01:27 -04:00
# Add blank space to start of prompt to match OG llama tokenizer
2024-06-13 09:45:24 +02:00
prefix_tokens : List [ int ] = (
2023-11-08 11:09:41 -05:00
(
2024-05-08 08:26:22 +02:00
[ prefix_token_id ]
if prefix_token_id > = 0 and suffix is not None
else [ ]
)
+
(
(
2024-06-13 09:45:24 +02:00
self . tokenize ( prompt . encode ( " utf-8 " ) , add_bos = False , special = ( prefix_token_id < 0 or suffix is None ) )
2024-05-08 08:26:22 +02:00
if prompt != " "
2024-06-13 09:45:24 +02:00
else [ ]
2024-05-08 08:26:22 +02:00
)
if isinstance ( prompt , str )
else prompt
)
2024-06-13 09:45:24 +02:00
)
suffix_tokens : List [ int ] = (
2024-05-08 08:26:22 +02:00
(
2024-06-13 09:45:24 +02:00
[ suffix_token_id ]
+
2024-05-08 08:26:22 +02:00
(
2024-06-13 09:45:24 +02:00
self . tokenize ( suffix . encode ( " utf-8 " ) , add_bos = False , special = False ) [ suffix_space_prefix : ]
if suffix
else [ ]
2024-05-08 08:26:22 +02:00
)
2023-11-08 11:09:41 -05:00
)
2024-06-13 09:45:24 +02:00
if suffix_token_id > = 0 and suffix is not None
else [ ]
)
middle_tokens : List [ int ] = (
[ middle_token_id ]
if middle_token_id > = 0 and suffix is not None
else [ ]
2023-11-08 11:09:41 -05:00
)
2024-06-13 09:45:24 +02:00
prompt_tokens : List [ int ] = bos_tokens + ( ( suffix_tokens + prefix_tokens + middle_tokens ) if self . spm_infill else ( prefix_tokens + suffix_tokens + middle_tokens ) ) + eos_tokens
2023-04-15 11:39:21 -04:00
text : bytes = b " "
2023-05-18 11:35:59 -04:00
returned_tokens : int = 0
2023-05-19 11:59:33 -04:00
stop = (
stop if isinstance ( stop , list ) else [ stop ] if isinstance ( stop , str ) else [ ]
)
2023-05-16 18:07:25 -04:00
model_name : str = model if model is not None else self . model_path
2023-03-23 05:33:06 -04:00
2024-06-04 16:15:41 +02:00
if prompt_tokens [ : 2 ] == [ self . token_bos ( ) ] * 2 :
warnings . warn (
f ' Detected duplicate leading " { self . _model . token_get_text ( self . token_bos ( ) ) } " in prompt, this will likely reduce response quality, consider removing it... ' ,
RuntimeWarning ,
)
2023-11-21 03:59:46 -05:00
# NOTE: This likely doesn't work correctly for the first token in the prompt
# because of the extra space added to the start of the prompt_tokens
if logit_bias is not None :
logit_bias_map = { int ( k ) : float ( v ) for k , v in logit_bias . items ( ) }
def logit_bias_processor (
input_ids : npt . NDArray [ np . intc ] ,
scores : npt . NDArray [ np . single ] ,
) - > npt . NDArray [ np . single ] :
new_scores = np . copy (
scores
) # Does it make sense to copy the whole array or can we just overwrite the original one?
for input_id , score in logit_bias_map . items ( ) :
new_scores [ input_id ] = score + scores [ input_id ]
return new_scores
_logit_bias_processor = LogitsProcessorList ( [ logit_bias_processor ] )
if logits_processor is None :
logits_processor = _logit_bias_processor
else :
logits_processor = logits_processor . extend ( _logit_bias_processor )
2023-04-04 13:09:24 -04:00
if self . verbose :
2023-11-06 09:16:36 -05:00
self . _ctx . reset_timings ( )
2023-04-04 13:09:24 -04:00
2023-11-06 09:16:36 -05:00
if len ( prompt_tokens ) > = self . _n_ctx :
2023-03-23 05:33:06 -04:00
raise ValueError (
2023-11-08 04:48:51 +01:00
f " Requested tokens ( { len ( prompt_tokens ) } ) exceed context window of { llama_cpp . llama_n_ctx ( self . ctx ) } "
2023-03-23 05:33:06 -04:00
)
2023-11-10 02:49:27 -05:00
if max_tokens is None or max_tokens < = 0 :
2023-07-09 18:13:29 -04:00
# Unlimited, depending on n_ctx.
2023-11-06 09:16:36 -05:00
max_tokens = self . _n_ctx - len ( prompt_tokens )
2023-07-09 18:13:29 -04:00
2023-06-09 10:57:36 -04:00
# Truncate max_tokens if requested tokens would exceed the context window
max_tokens = (
max_tokens
if max_tokens + len ( prompt_tokens ) < self . _n_ctx
else ( self . _n_ctx - len ( prompt_tokens ) )
)
2023-04-01 13:01:27 -04:00
if stop != [ ] :
2023-04-02 03:59:19 -04:00
stop_sequences = [ s . encode ( " utf-8 " ) for s in stop ]
2023-04-01 13:01:27 -04:00
else :
2023-04-02 03:59:19 -04:00
stop_sequences = [ ]
2023-03-24 14:33:38 -04:00
2023-09-30 16:02:35 -04:00
if logprobs is not None and self . context_params . logits_all is False :
2023-04-12 14:05:11 -04:00
raise ValueError (
" logprobs is not supported for models created with logits_all=False "
)
2023-06-10 12:22:31 -04:00
if self . cache :
2023-05-07 19:31:26 -04:00
try :
cache_item = self . cache [ prompt_tokens ]
cache_prefix_len = Llama . longest_token_prefix (
2023-05-26 20:12:05 -04:00
cache_item . input_ids . tolist ( ) , prompt_tokens
2023-05-07 19:31:26 -04:00
)
eval_prefix_len = Llama . longest_token_prefix (
2023-05-26 20:12:05 -04:00
self . _input_ids . tolist ( ) , prompt_tokens
2023-05-07 19:31:26 -04:00
)
if cache_prefix_len > eval_prefix_len :
self . load_state ( cache_item )
if self . verbose :
print ( " Llama._create_completion: cache hit " , file = sys . stderr )
except KeyError :
if self . verbose :
print ( " Llama._create_completion: cache miss " , file = sys . stderr )
2023-11-08 11:09:41 -05:00
2023-11-07 23:37:28 -05:00
if seed is not None :
self . _ctx . set_rng_seed ( seed )
2023-04-15 12:03:09 -04:00
2023-04-12 14:05:11 -04:00
finish_reason = " length "
2023-04-28 12:50:30 +02:00
multibyte_fix = 0
2023-04-01 13:01:27 -04:00
for token in self . generate (
prompt_tokens ,
top_k = top_k ,
top_p = top_p ,
2023-11-21 06:21:33 +02:00
min_p = min_p ,
typical_p = typical_p ,
2023-04-01 13:01:27 -04:00
temp = temperature ,
2023-06-08 13:19:23 -04:00
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
frequency_penalty = frequency_penalty ,
presence_penalty = presence_penalty ,
2023-04-01 13:01:27 -04:00
repeat_penalty = repeat_penalty ,
2023-06-08 13:19:23 -04:00
stopping_criteria = stopping_criteria ,
logits_processor = logits_processor ,
2023-08-08 15:08:54 -04:00
grammar = grammar ,
2023-03-28 04:03:57 -04:00
) :
2024-04-22 00:35:47 -04:00
assert self . _model . model is not None
if llama_cpp . llama_token_is_eog ( self . _model . model , token ) :
2024-02-23 12:23:24 -05:00
text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
2023-03-23 05:33:06 -04:00
finish_reason = " stop "
break
2023-04-24 19:54:41 -04:00
2023-03-28 01:45:37 -04:00
completion_tokens . append ( token )
2023-03-23 05:33:06 -04:00
2024-02-23 12:23:24 -05:00
all_text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
2023-04-28 13:16:18 +02:00
# Contains multi-byte UTF8
2023-05-01 14:47:55 -04:00
for k , char in enumerate ( all_text [ - 3 : ] ) :
2023-04-28 13:16:18 +02:00
k = 3 - k
2023-05-01 14:47:55 -04:00
for num , pattern in [ ( 2 , 192 ) , ( 3 , 224 ) , ( 4 , 240 ) ] :
2023-04-28 13:16:18 +02:00
# Bitwise AND check
2023-05-01 14:47:55 -04:00
if num > k and pattern & char == pattern :
2023-04-28 13:16:18 +02:00
multibyte_fix = num - k
2023-04-28 12:50:30 +02:00
# Stop incomplete bytes from passing
2023-05-01 14:47:55 -04:00
if multibyte_fix > 0 :
2023-04-28 12:50:30 +02:00
multibyte_fix - = 1
continue
2023-04-02 03:59:19 -04:00
any_stop = [ s for s in stop_sequences if s in all_text ]
2023-03-23 05:33:06 -04:00
if len ( any_stop ) > 0 :
first_stop = any_stop [ 0 ]
2023-04-02 03:59:19 -04:00
text = all_text [ : all_text . index ( first_stop ) ]
2023-03-23 05:33:06 -04:00
finish_reason = " stop "
break
2023-03-28 04:03:57 -04:00
if stream :
2023-05-26 20:23:49 -04:00
remaining_tokens = completion_tokens [ returned_tokens : ]
2024-02-23 12:23:24 -05:00
remaining_text = self . detokenize ( remaining_tokens , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] )
2023-05-26 20:23:49 -04:00
remaining_length = len ( remaining_text )
2023-04-02 03:59:19 -04:00
# We want to avoid yielding any characters from
# the generated text if they are part of a stop
# sequence.
2023-05-19 02:20:27 -04:00
first_stop_position = 0
2023-04-02 03:59:19 -04:00
for s in stop_sequences :
2023-05-26 20:23:49 -04:00
for i in range ( min ( len ( s ) , remaining_length ) , 0 , - 1 ) :
if remaining_text . endswith ( s [ : i ] ) :
2023-05-19 02:20:27 -04:00
if i > first_stop_position :
first_stop_position = i
2023-03-28 04:03:57 -04:00
break
2023-05-18 11:35:59 -04:00
2023-05-19 02:20:27 -04:00
token_end_position = 0
2023-08-09 22:04:35 +08:00
if logprobs is not None :
# not sure how to handle this branch when dealing
# with CJK output, so keep it unchanged
for token in remaining_tokens :
2024-06-13 09:45:24 +02:00
if token == bos_token_id :
2023-11-20 22:50:59 -05:00
continue
2024-02-23 12:23:24 -05:00
token_end_position + = len ( self . detokenize ( [ token ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] ) )
2023-08-09 22:04:35 +08:00
# Check if stop sequence is in the token
2023-09-28 22:42:03 -04:00
if token_end_position > (
remaining_length - first_stop_position
) :
2023-08-09 22:04:35 +08:00
break
2024-02-23 12:23:24 -05:00
token_str = self . detokenize ( [ token ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] ) . decode (
2023-05-19 02:20:27 -04:00
" utf-8 " , errors = " ignore "
)
text_offset = len ( prompt ) + len (
2024-02-23 12:23:24 -05:00
self . detokenize ( completion_tokens [ : returned_tokens ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] ) . decode (
2023-12-22 14:03:29 +09:00
" utf-8 " , errors = " ignore "
)
2023-05-19 02:20:27 -04:00
)
token_offset = len ( prompt_tokens ) + returned_tokens
2023-12-18 11:28:12 -08:00
logits = self . _scores [ token_offset - 1 , : ]
2023-12-18 18:40:36 -05:00
current_logprobs = Llama . logits_to_logprobs ( logits ) . tolist ( )
2023-05-19 02:20:27 -04:00
sorted_logprobs = list (
sorted (
zip ( current_logprobs , range ( len ( current_logprobs ) ) ) ,
reverse = True ,
)
)
top_logprob = {
2023-05-19 11:59:33 -04:00
self . detokenize ( [ i ] ) . decode (
2023-05-19 02:20:27 -04:00
" utf-8 " , errors = " ignore "
) : logprob
for logprob , i in sorted_logprobs [ : logprobs ]
}
top_logprob . update ( { token_str : current_logprobs [ int ( token ) ] } )
logprobs_or_none = {
2024-02-09 02:02:13 -05:00
" tokens " : [
2024-02-23 12:23:24 -05:00
self . detokenize ( [ token ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] ) . decode (
2024-02-09 02:02:13 -05:00
" utf-8 " , errors = " ignore "
)
] ,
2023-05-19 02:20:27 -04:00
" text_offset " : [ text_offset ] ,
2023-07-07 10:18:49 +00:00
" token_logprobs " : [ current_logprobs [ int ( token ) ] ] ,
2023-05-19 02:20:27 -04:00
" top_logprobs " : [ top_logprob ] ,
}
2023-08-09 22:04:35 +08:00
returned_tokens + = 1
yield {
" id " : completion_id ,
" object " : " text_completion " ,
" created " : created ,
" model " : model_name ,
" choices " : [
{
2024-02-23 12:23:24 -05:00
" text " : self . detokenize ( [ token ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] ) . decode (
2024-02-09 02:02:13 -05:00
" utf-8 " , errors = " ignore "
) ,
2023-08-09 22:04:35 +08:00
" index " : 0 ,
" logprobs " : logprobs_or_none ,
" finish_reason " : None ,
}
] ,
}
else :
while len ( remaining_tokens ) > 0 :
decode_success = False
for i in range ( 1 , len ( remaining_tokens ) + 1 ) :
try :
2024-02-23 12:23:24 -05:00
bs = self . detokenize ( remaining_tokens [ : i ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] )
2023-09-28 22:42:03 -04:00
ts = bs . decode ( " utf-8 " )
2023-08-09 22:04:35 +08:00
decode_success = True
break
except UnicodeError :
pass
2023-08-29 07:21:59 -04:00
else :
break
2023-08-09 22:04:35 +08:00
if not decode_success :
# all remaining tokens cannot be decoded to a UTF-8 character
break
token_end_position + = len ( bs )
2023-09-28 22:42:03 -04:00
if token_end_position > (
remaining_length - first_stop_position
) :
2023-08-09 22:04:35 +08:00
break
remaining_tokens = remaining_tokens [ i : ]
returned_tokens + = i
yield {
" id " : completion_id ,
" object " : " text_completion " ,
" created " : created ,
" model " : model_name ,
" choices " : [
{
2023-08-29 07:21:59 -04:00
" text " : ts ,
2023-08-09 22:04:35 +08:00
" index " : 0 ,
" logprobs " : None ,
" finish_reason " : None ,
}
] ,
}
2023-04-12 14:05:11 -04:00
2023-04-02 03:59:19 -04:00
if len ( completion_tokens ) > = max_tokens :
2024-02-23 12:23:24 -05:00
text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
2023-04-02 03:59:19 -04:00
finish_reason = " length "
break
2023-03-23 05:33:06 -04:00
2023-05-26 03:13:24 -04:00
if stopping_criteria is not None and stopping_criteria (
2023-07-18 19:27:41 -04:00
self . _input_ids , self . _scores [ - 1 , : ]
2023-05-26 03:13:24 -04:00
) :
2024-02-23 12:23:24 -05:00
text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
2023-05-26 03:13:24 -04:00
finish_reason = " stop "
2023-05-10 16:12:17 -04:00
if self . verbose :
2023-11-06 09:16:36 -05:00
self . _ctx . print_timings ( )
2023-05-10 16:12:17 -04:00
2023-03-28 04:03:57 -04:00
if stream :
2023-05-18 11:35:59 -04:00
remaining_tokens = completion_tokens [ returned_tokens : ]
2024-02-23 12:23:24 -05:00
all_text = self . detokenize ( remaining_tokens , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] )
2023-05-18 11:35:59 -04:00
any_stop = [ s for s in stop_sequences if s in all_text ]
if len ( any_stop ) > 0 :
end = min ( all_text . index ( stop ) for stop in any_stop )
else :
end = len ( all_text )
2023-05-19 02:20:27 -04:00
token_end_position = 0
2023-05-18 11:35:59 -04:00
for token in remaining_tokens :
2024-02-23 12:23:24 -05:00
token_end_position + = len ( self . detokenize ( [ token ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] ) )
2023-05-19 02:20:27 -04:00
logprobs_or_none : Optional [ CompletionLogprobs ] = None
if logprobs is not None :
2024-06-13 09:45:24 +02:00
if token == bos_token_id :
2023-11-20 22:50:59 -05:00
continue
2023-05-19 02:20:27 -04:00
token_str = self . detokenize ( [ token ] ) . decode (
" utf-8 " , errors = " ignore "
)
text_offset = len ( prompt ) + len (
2024-02-23 12:23:24 -05:00
self . detokenize ( completion_tokens [ : returned_tokens ] , prev_tokens = prompt_tokens + completion_tokens [ : returned_tokens ] )
2023-05-19 02:20:27 -04:00
)
token_offset = len ( prompt_tokens ) + returned_tokens - 1
2023-12-18 11:28:12 -08:00
logits = self . _scores [ token_offset , : ]
2023-12-18 18:40:36 -05:00
current_logprobs = Llama . logits_to_logprobs ( logits ) . tolist ( )
2023-05-19 02:20:27 -04:00
sorted_logprobs = list (
sorted (
zip ( current_logprobs , range ( len ( current_logprobs ) ) ) ,
reverse = True ,
)
)
top_logprob = {
2023-05-19 11:59:33 -04:00
self . detokenize ( [ i ] ) . decode ( " utf-8 " , errors = " ignore " ) : logprob
2023-05-19 02:20:27 -04:00
for logprob , i in sorted_logprobs [ : logprobs ]
}
top_logprob . update ( { token_str : current_logprobs [ int ( token ) ] } )
logprobs_or_none = {
" tokens " : [
self . detokenize ( [ token ] ) . decode ( " utf-8 " , errors = " ignore " )
] ,
" text_offset " : [ text_offset ] ,
2023-07-07 10:18:49 +00:00
" token_logprobs " : [ current_logprobs [ int ( token ) ] ] ,
2023-05-19 02:20:27 -04:00
" top_logprobs " : [ top_logprob ] ,
}
if token_end_position > = end :
2023-05-18 11:35:59 -04:00
last_text = self . detokenize ( [ token ] )
2023-05-19 02:20:27 -04:00
if token_end_position == end - 1 :
2023-05-18 11:35:59 -04:00
break
2023-05-19 02:20:27 -04:00
returned_tokens + = 1
2023-05-18 11:35:59 -04:00
yield {
" id " : completion_id ,
" object " : " text_completion " ,
" created " : created ,
" model " : model_name ,
" choices " : [
{
" text " : last_text [
2023-06-08 13:19:23 -04:00
: len ( last_text ) - ( token_end_position - end )
] . decode ( " utf-8 " , errors = " ignore " ) ,
2023-05-18 11:35:59 -04:00
" index " : 0 ,
2023-05-19 02:20:27 -04:00
" logprobs " : logprobs_or_none ,
2023-07-08 00:06:11 -04:00
" finish_reason " : None ,
}
] ,
}
2023-05-18 11:35:59 -04:00
break
returned_tokens + = 1
2023-03-28 04:03:57 -04:00
yield {
" id " : completion_id ,
" object " : " text_completion " ,
" created " : created ,
2023-05-18 11:35:59 -04:00
" model " : model_name ,
2023-03-28 04:03:57 -04:00
" choices " : [
{
2023-05-18 11:35:59 -04:00
" text " : self . detokenize ( [ token ] ) . decode (
" utf-8 " , errors = " ignore "
) ,
2023-03-28 04:03:57 -04:00
" index " : 0 ,
2023-05-19 02:20:27 -04:00
" logprobs " : logprobs_or_none ,
2023-03-28 04:03:57 -04:00
" finish_reason " : None ,
}
] ,
}
2023-10-19 02:55:56 -04:00
yield {
" id " : completion_id ,
" object " : " text_completion " ,
" created " : created ,
" model " : model_name ,
" choices " : [
{
" text " : " " ,
" index " : 0 ,
" logprobs " : None ,
" finish_reason " : finish_reason ,
}
] ,
}
2023-06-10 12:22:31 -04:00
if self . cache :
2023-05-26 03:03:01 -04:00
if self . verbose :
print ( " Llama._create_completion: cache save " , file = sys . stderr )
self . cache [ prompt_tokens + completion_tokens ] = self . save_state ( )
2023-06-08 13:19:23 -04:00
print ( " Llama._create_completion: cache saved " , file = sys . stderr )
2023-03-28 04:03:57 -04:00
return
2023-06-10 12:22:31 -04:00
if self . cache :
2023-05-26 03:03:01 -04:00
if self . verbose :
print ( " Llama._create_completion: cache save " , file = sys . stderr )
self . cache [ prompt_tokens + completion_tokens ] = self . save_state ( )
2023-04-26 14:37:06 +02:00
text_str = text . decode ( " utf-8 " , errors = " ignore " )
2023-03-23 16:25:13 -04:00
2023-03-23 05:33:06 -04:00
if echo :
2023-04-15 12:03:09 -04:00
text_str = prompt + text_str
2023-03-23 05:33:06 -04:00
2024-05-08 08:26:22 +02:00
if suffix_token_id < 0 and suffix is not None :
2023-04-15 12:03:09 -04:00
text_str = text_str + suffix
2023-03-23 05:33:06 -04:00
2023-04-12 14:05:11 -04:00
logprobs_or_none : Optional [ CompletionLogprobs ] = None
2023-03-23 15:51:05 -04:00
if logprobs is not None :
2023-05-19 02:20:27 -04:00
text_offset = 0 if echo else len ( prompt )
token_offset = 0 if echo else len ( prompt_tokens [ 1 : ] )
2023-04-14 09:59:33 -04:00
text_offsets : List [ int ] = [ ]
2023-05-19 02:20:27 -04:00
token_logprobs : List [ Optional [ float ] ] = [ ]
2023-04-14 09:59:33 -04:00
tokens : List [ str ] = [ ]
2023-05-19 02:20:27 -04:00
top_logprobs : List [ Optional [ Dict [ str , float ] ] ] = [ ]
if echo :
2024-06-04 16:18:38 +02:00
# Remove leading BOS token if exists
all_tokens = prompt_tokens [ 1 if prompt_tokens [ 0 ] == self . token_bos ( ) else 0 : ] + completion_tokens
2023-05-19 02:20:27 -04:00
else :
all_tokens = completion_tokens
2023-04-14 09:59:33 -04:00
all_token_strs = [
2024-02-23 12:23:24 -05:00
self . detokenize ( [ token ] , prev_tokens = all_tokens [ : i ] ) . decode ( " utf-8 " , errors = " ignore " )
for i , token in enumerate ( all_tokens )
2023-04-14 09:59:33 -04:00
]
2023-12-18 11:28:12 -08:00
all_logprobs = Llama . logits_to_logprobs ( self . _scores ) [ token_offset : ]
# TODO: may be able to change this loop to use np.take_along_dim
2023-12-22 14:03:29 +09:00
for idx , ( token , token_str , logprobs_token ) in enumerate (
zip ( all_tokens , all_token_strs , all_logprobs )
2023-04-14 09:59:33 -04:00
) :
2024-06-13 09:45:24 +02:00
if token == bos_token_id :
2023-11-20 22:50:59 -05:00
continue
2023-12-22 14:03:29 +09:00
text_offsets . append (
text_offset
+ len (
self . detokenize ( all_tokens [ : idx ] ) . decode (
" utf-8 " , errors = " ignore "
)
)
)
2023-04-14 09:59:33 -04:00
tokens . append ( token_str )
sorted_logprobs = list (
sorted (
zip ( logprobs_token , range ( len ( logprobs_token ) ) ) , reverse = True
)
)
2023-07-07 10:18:49 +00:00
token_logprobs . append ( logprobs_token [ int ( token ) ] )
2023-05-19 02:20:27 -04:00
top_logprob : Optional [ Dict [ str , float ] ] = {
2024-02-23 12:23:24 -05:00
self . detokenize ( [ i ] , prev_tokens = all_tokens [ : idx ] ) . decode ( " utf-8 " , errors = " ignore " ) : logprob
2023-04-14 09:59:33 -04:00
for logprob , i in sorted_logprobs [ : logprobs ]
}
2023-05-19 02:20:27 -04:00
top_logprob . update ( { token_str : logprobs_token [ int ( token ) ] } )
2023-04-14 09:59:33 -04:00
top_logprobs . append ( top_logprob )
2023-05-19 02:20:27 -04:00
# Weird idosincracy of the OpenAI API where
# token_logprobs and top_logprobs are null for
# the first token.
if echo and len ( all_tokens ) > 0 :
token_logprobs [ 0 ] = None
top_logprobs [ 0 ] = None
2023-04-12 14:05:11 -04:00
logprobs_or_none = {
" tokens " : tokens ,
" text_offset " : text_offsets ,
" token_logprobs " : token_logprobs ,
" top_logprobs " : top_logprobs ,
}
2023-04-04 13:09:24 -04:00
2023-03-28 04:03:57 -04:00
yield {
2023-03-28 02:42:22 -04:00
" id " : completion_id ,
2023-03-23 05:33:06 -04:00
" object " : " text_completion " ,
2023-03-28 02:42:22 -04:00
" created " : created ,
2023-05-16 18:07:25 -04:00
" model " : model_name ,
2023-03-23 05:33:06 -04:00
" choices " : [
{
2023-04-15 12:03:09 -04:00
" text " : text_str ,
2023-03-23 05:33:06 -04:00
" index " : 0 ,
2023-04-12 14:05:11 -04:00
" logprobs " : logprobs_or_none ,
2023-03-23 05:33:06 -04:00
" finish_reason " : finish_reason ,
}
] ,
" usage " : {
2023-03-28 01:45:37 -04:00
" prompt_tokens " : len ( prompt_tokens ) ,
" completion_tokens " : len ( completion_tokens ) ,
" total_tokens " : len ( prompt_tokens ) + len ( completion_tokens ) ,
2023-03-23 05:33:06 -04:00
} ,
}
2023-04-01 13:01:27 -04:00
def create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    # NOTE: the list default is kept for interface compatibility; it is only
    # forwarded below and never mutated in this method.
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt.

    Thin wrapper around ``_create_completion``: every argument is forwarded
    by keyword, and the resulting iterator is either returned as-is
    (``stream=True``) or reduced to its first item (``stream=False``).

    Args:
        prompt: The prompt to generate text from.
        suffix: A suffix to append to the generated text. If None, no suffix is appended.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx. ``None`` is forwarded as ``-1``.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        logprobs: The number of logprobs to return. If None, no logprobs are returned.
        echo: Whether to echo the prompt.
        stop: A list of strings to stop generation when encountered.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        stream: Whether to stream the results.
        seed: The seed to use for sampling.
        tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
        mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
        model: The name to use for the model in the completion object.
        stopping_criteria: A list of stopping criteria to use.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use for constrained sampling.
        logit_bias: A logit bias to use.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        Response object containing the generated text, or an iterator of
        chunks when ``stream`` is true.
    """
    completion_or_chunks = self._create_completion(
        prompt=prompt,
        suffix=suffix,
        # The internal implementation uses -1 (not None) for "no explicit limit".
        max_tokens=-1 if max_tokens is None else max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    if stream:
        # Streaming callers consume the chunk iterator themselves.
        chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
        return chunks
    # Non-streaming: the first item produced is taken as the whole completion.
    completion: Completion = next(completion_or_chunks)  # type: ignore
    return completion
2023-03-28 04:03:57 -04:00
def __call__(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    # NOTE: the list default is kept for interface compatibility; it is only
    # forwarded below and never mutated in this method.
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt.

    Calling the instance is a direct alias for ``create_completion``: every
    argument is forwarded unchanged by keyword.

    Args:
        prompt: The prompt to generate text from.
        suffix: A suffix to append to the generated text. If None, no suffix is appended.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        logprobs: The number of logprobs to return. If None, no logprobs are returned.
        echo: Whether to echo the prompt.
        stop: A list of strings to stop generation when encountered.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        stream: Whether to stream the results.
        seed: The seed to use for sampling.
        tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
        mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
        model: The name to use for the model in the completion object.
        stopping_criteria: A list of stopping criteria to use.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use for constrained sampling.
        logit_bias: A logit bias to use.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        Response object containing the generated text.
    """
    return self.create_completion(
        prompt=prompt,
        suffix=suffix,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
2023-04-03 20:12:44 -04:00
def create_chat_completion(
    self,
    messages: List[ChatCompletionRequestMessage],
    functions: Optional[List[ChatCompletionFunction]] = None,
    function_call: Optional[ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[ChatCompletionTool]] = None,
    tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    # NOTE: the list default is kept for interface compatibility; it is only
    # forwarded below and never mutated in this method.
    stop: Optional[Union[str, List[str]]] = [],
    seed: Optional[int] = None,
    response_format: Optional[ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
) -> Union[
    CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
    """Generate a chat completion from a list of messages.

    The work is delegated to a chat handler, which is called with this
    instance as ``llama`` plus every argument below by keyword.

    Args:
        messages: A list of messages to generate a response for.
        functions: A list of functions to use for the chat completion.
        function_call: A function call to use for the chat completion.
        tools: A list of tools to use for the chat completion.
        tool_choice: A tool choice to use for the chat completion.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        stream: Whether to stream the results.
        stop: A list of strings to stop generation when encountered.
        seed: The seed to use for sampling.
        response_format: The response format to use for the chat completion. Use { "type": "json_object" } to constrain output to only valid json.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        tfs_z: The tail-free sampling parameter.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The mirostat sampling tau parameter.
        mirostat_eta: The mirostat sampling eta parameter.
        model: The name to use for the model in the completion object.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use.
        logit_bias: A logit bias to use.
        logprobs: Forwarded to the chat handler; presumably toggles returning token log probabilities (confirm against the handler).
        top_logprobs: Forwarded to the chat handler; presumably the number of alternatives reported per token (confirm against the handler).

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    # Handler precedence: an explicitly configured handler, then one registered
    # on this instance for the chat format, then the library-wide lookup.
    handler = (
        self.chat_handler
        or self._chat_handlers.get(self.chat_format)
        or llama_chat_format.get_chat_completion_handler(self.chat_format)
    )
    return handler(
        llama=self,
        messages=messages,
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        stream=stream,
        stop=stop,
        seed=seed,
        response_format=response_format,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
2024-02-12 15:56:07 -05:00
def create_chat_completion_openai_v1(
    self,
    *args: Any,
    **kwargs: Any,
):
    """Generate a chat completion with return type based on the OpenAI v1 API.

    OpenAI python package is required to use this method.

    You can install it with `pip install openai`.

    Streaming is detected from the ``stream`` keyword argument only, so
    ``stream`` must be passed by keyword for chunked output to be wrapped.

    Args:
        *args: Positional arguments to pass to create_chat_completion.
        **kwargs: Keyword arguments to pass to create_chat_completion.

    Raises:
        ImportError: If the ``openai`` package cannot be imported.

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    # Guard only the optional-dependency import. An ImportError raised while
    # generating the completion must surface unchanged rather than be
    # misreported as "openai is not installed".
    try:
        from openai.types.chat import ChatCompletion, ChatCompletionChunk
    except ImportError as err:
        raise ImportError(
            "To use create_chat_completion_openai_v1, you must install the openai package. "
            "You can install it with `pip install openai`."
        ) from err

    stream = kwargs.get("stream", False)  # type: ignore
    assert isinstance(stream, bool)
    if stream:
        # Lazy generator: each chunk dict is wrapped as it is produced.
        return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
    return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
2023-04-05 06:52:17 -04:00
def __getstate__(self):
    """Return the constructor keyword arguments that describe this instance.

    The mapping is consumed by ``__setstate__``, which passes it straight to
    ``__init__``; every key is therefore expected to be an ``__init__``
    keyword. Values are read from the instance and from its
    ``model_params`` / ``context_params`` objects.
    """
    return dict(
        model_path=self.model_path,
        # Model Params
        n_gpu_layers=self.model_params.n_gpu_layers,
        split_mode=self.model_params.split_mode,
        main_gpu=self.model_params.main_gpu,
        tensor_split=self.tensor_split,
        vocab_only=self.model_params.vocab_only,
        use_mmap=self.model_params.use_mmap,
        use_mlock=self.model_params.use_mlock,
        kv_overrides=self.kv_overrides,
        # Context Params
        seed=self.context_params.seed,
        n_ctx=self.context_params.n_ctx,
        n_batch=self.n_batch,
        n_threads=self.context_params.n_threads,
        n_threads_batch=self.context_params.n_threads_batch,
        rope_scaling_type=self.context_params.rope_scaling_type,
        pooling_type=self.context_params.pooling_type,
        rope_freq_base=self.context_params.rope_freq_base,
        rope_freq_scale=self.context_params.rope_freq_scale,
        yarn_ext_factor=self.context_params.yarn_ext_factor,
        yarn_attn_factor=self.context_params.yarn_attn_factor,
        yarn_beta_fast=self.context_params.yarn_beta_fast,
        yarn_beta_slow=self.context_params.yarn_beta_slow,
        yarn_orig_ctx=self.context_params.yarn_orig_ctx,
        logits_all=self.context_params.logits_all,
        # The state key is ``embedding`` while the context field is ``embeddings``.
        embedding=self.context_params.embeddings,
        offload_kqv=self.context_params.offload_kqv,
        flash_attn=self.context_params.flash_attn,
        # Sampling Params
        last_n_tokens_size=self.last_n_tokens_size,
        # LoRA Params
        lora_base=self.lora_base,
        lora_scale=self.lora_scale,
        lora_path=self.lora_path,
        # Backend Params
        numa=self.numa,
        # Chat Format Params
        chat_format=self.chat_format,
        chat_handler=self.chat_handler,
        # Speculative Decoding
        draft_model=self.draft_model,
        # KV cache quantization
        type_k=self.context_params.type_k,
        type_v=self.context_params.type_v,
        # Misc
        spm_infill=self.spm_infill,
        verbose=self.verbose,
    )
def __setstate__(self, state):
    """Rebuild this instance from a saved state mapping.

    Args:
        state: Mapping of keyword arguments; it is unpacked straight into
            ``__init__``, so restoring re-runs the constructor.
    """
    init_kwargs = state
    self.__init__(**init_kwargs)
2023-04-05 06:52:17 -04:00
2023-04-24 17:51:25 -04:00
def save_state(self) -> LlamaState:
    """Snapshot the llama context state together with the Python-side buffers.

    Returns:
        A ``LlamaState`` holding copies of ``self._scores`` and
        ``self.input_ids``, the current ``n_tokens``, and the serialized
        context state trimmed to the number of bytes actually written.

    Raises:
        RuntimeError: If ``llama_copy_state_data`` reports more bytes than
            the buffer that was allocated for it.
    """
    assert self._ctx.ctx is not None
    if self.verbose:
        print("Llama.save_state: saving llama state", file=sys.stderr)
    state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
    if self.verbose:
        print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
    # Scratch buffer sized to the reported upper bound of the state.
    llama_state = (ctypes.c_uint8 * int(state_size))()
    if self.verbose:
        print("Llama.save_state: allocated state", file=sys.stderr)
    n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
    if self.verbose:
        print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")
    # Trim to the bytes actually written. Slicing a memoryview does not copy,
    # so the only copy made here is the final `bytes` object (the previous
    # implementation first memmove'd into a second full-size ctypes buffer).
    compact_state = bytes(memoryview(llama_state)[: int(n_bytes)])
    if self.verbose:
        print(
            f"Llama.save_state: saving {n_bytes} bytes of llama state",
            file=sys.stderr,
        )
    return LlamaState(
        scores=self._scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=compact_state,
        llama_state_size=n_bytes,
    )
def load_state(self, state: LlamaState) -> None:
    """Restore the context and Python-side buffers from a ``LlamaState``.

    Args:
        state: Snapshot to restore; its ``scores``, ``input_ids``,
            ``n_tokens``, ``llama_state`` and ``llama_state_size`` are read.

    Raises:
        RuntimeError: If ``llama_set_state_data`` does not report exactly
            ``state.llama_state_size`` bytes.
    """
    assert self._ctx.ctx is not None
    # Only filling in up to `n_tokens` and then zero-ing out the rest.
    # Slice assignment copies the values into `self.scores`, so taking an
    # extra `.copy()` of the (potentially large) saved array is unnecessary.
    self.scores[: state.n_tokens, :] = state.scores
    self.scores[state.n_tokens :, :] = 0.0
    self.input_ids = state.input_ids.copy()
    self.n_tokens = state.n_tokens
    state_size = state.llama_state_size
    # Copy the immutable bytes into a ctypes array that can be handed to C.
    LLamaStateArrayType = ctypes.c_uint8 * state_size
    llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
    if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
        raise RuntimeError("Failed to set llama state data")
2023-05-20 08:13:41 -04:00
def n_ctx(self) -> int:
    """Get the context window size as reported by the underlying context."""
    context = self._ctx
    return context.n_ctx()
2023-05-20 08:13:41 -04:00
def n_embd(self) -> int:
    """Get the embedding size as reported by the underlying model."""
    model = self._model
    return model.n_embd()
2023-05-20 08:13:41 -04:00
def n_vocab(self) -> int:
    """Get the vocabulary size as reported by the underlying model."""
    model = self._model
    return model.n_vocab()
2023-05-20 08:13:41 -04:00
2024-02-08 09:07:03 +08:00
def tokenizer(self) -> LlamaTokenizer:
    """Build a ``LlamaTokenizer`` wrapping this model instance."""
    llama_tokenizer = LlamaTokenizer(self)
    return llama_tokenizer
2023-04-05 06:52:17 -04:00
2023-08-24 00:17:00 -04:00
def token_eos(self) -> int:
    """Get the end-of-sequence token id from the underlying model."""
    model = self._model
    return model.token_eos()
2023-04-01 17:29:30 -04:00
2023-08-24 00:17:00 -04:00
def token_bos(self) -> int:
    """Get the beginning-of-sequence token id from the underlying model."""
    model = self._model
    return model.token_bos()
2023-04-12 14:05:11 -04:00
2023-08-24 00:17:00 -04:00
def token_nl(self) -> int:
    """Get the newline token id from the underlying model."""
    model = self._model
    return model.token_nl()
2023-05-17 01:53:26 -04:00
2024-04-25 20:32:44 -05:00
def pooling_type(self) -> str:
    """Get the pooling type from the underlying context.

    The value is forwarded unchanged from ``self._ctx.pooling_type()``.
    NOTE(review): annotated as ``str``; confirm against what the context
    object actually returns.
    """
    context = self._ctx
    return context.pooling_type()
feat: Add `.close()` method to `Llama` class to explicitly free model from memory (#1513)
* feat: add explicit methods to free model
This commit introduces a `close` method to both `Llama` and `_LlamaModel`,
allowing users to explicitly free the model from RAM/VRAM.
The previous implementation relied on the destructor of `_LlamaModel` to free
the model. However, in Python, the timing of destructor calls is unclear—for
instance, the `del` statement does not guarantee immediate invocation of the
destructor.
This commit provides an explicit method to release the model, which works
immediately and allows the user to load another model without memory issues.
Additionally, this commit implements a context manager in the `Llama` class,
enabling the automatic closure of the `Llama` object when used with the `with`
statement.
* feat: Implement ContextManager in _LlamaModel, _LlamaContext, and _LlamaBatch
This commit enables automatic resource management by
implementing the `ContextManager` protocol in `_LlamaModel`,
`_LlamaContext`, and `_LlamaBatch`. This ensures that
resources are properly managed and released within a `with`
statement, enhancing robustness and safety in resource handling.
* feat: add ExitStack for Llama's internal class closure
This update implements ExitStack to manage and close internal
classes in Llama, enhancing efficient and safe resource
management.
* Use contextlib ExitStack and closing
* Explicitly free model when closing resources on server
---------
Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-06-13 02:16:14 -06:00
def close(self) -> None:
    """Free the model from memory right away by closing ``self._stack``."""
    resource_stack = self._stack
    resource_stack.close()
2023-04-12 14:05:11 -04:00
@staticmethod
2023-12-16 15:59:26 -08:00
def logits_to_logprobs(
    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    """Turn logits into log-probabilities with a max-shifted log-softmax.

    Args:
        logits: Array or (nested) list of logits.
        axis: Axis along which the normalisation is computed.

    Returns:
        ``shifted - log(sum(exp(shifted)))`` where ``shifted`` is ``logits``
        minus the per-axis maximum, computed as ``np.single``.
    """
    # Follows scipy's log_softmax:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
    peak: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
    # A non-finite maximum would poison the subtraction, so shift by 0 instead.
    if peak.ndim == 0:
        if not np.isfinite(peak):
            peak = 0
    else:
        peak[~np.isfinite(peak)] = 0
    shifted = np.subtract(logits, peak, dtype=np.single)
    exponentials = np.exp(shifted)
    # Suppress warnings about log of zero
    with np.errstate(divide="ignore"):
        log_normaliser = np.log(np.sum(exponentials, axis=axis, keepdims=True))
    return shifted - log_normaliser
2023-05-07 19:31:26 -04:00
@staticmethod
2023-05-19 11:59:33 -04:00
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
    """Count how many leading items of ``a`` and ``b`` are equal.

    Comparison stops at the first mismatch or when the shorter input is
    exhausted; the number of matched positions is returned.
    """
    matched = 0
    for left, right in zip(a, b):
        if not left == right:
            return matched
        matched += 1
    return matched
2023-05-25 14:11:33 -04:00
2024-02-21 16:25:10 -05:00
@classmethod
def from_pretrained(
    cls,
    repo_id: str,
    filename: Optional[str],
    local_dir: Optional[Union[str, os.PathLike[str]]] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    **kwargs: Any,
) -> "Llama":
    """Create a Llama model from a pretrained model name or path.
    This method requires the huggingface-hub package.
    You can install it with `pip install huggingface-hub`.

    Args:
        repo_id: The model repo id.
        filename: A filename or glob pattern to match the model file in the repo.
        local_dir: The local directory to save the model to.
        local_dir_use_symlinks: Whether to use symlinks when downloading the model.
        cache_dir: Cache directory forwarded to `hf_hub_download`.
        **kwargs: Additional keyword arguments to pass to the Llama constructor.

    Returns:
        A Llama model.

    Raises:
        ImportError: If the huggingface-hub package is not installed.
        ValueError: If no file, or more than one file, in the repo matches
            `filename`.
    """
    try:
        from huggingface_hub import hf_hub_download, HfFileSystem
        from huggingface_hub.utils import validate_repo_id
    except ImportError:
        raise ImportError(
            "Llama.from_pretrained requires the huggingface-hub package. "
            "You can install it with `pip install huggingface-hub`."
        )

    validate_repo_id(repo_id)

    hffs = HfFileSystem()

    files = [
        file["name"] if isinstance(file, dict) else file
        for file in hffs.ls(repo_id)
    ]

    # Strip the repo id prefix so patterns are matched against
    # repo-relative paths.
    file_list: List[str] = [str(Path(file).relative_to(repo_id)) for file in files]

    matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

    if len(matching_files) == 0:
        raise ValueError(
            f"No file found in {repo_id} that match {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    if len(matching_files) > 1:
        # Report the same repo-relative names the pattern was matched
        # against, consistent with the "No file found" error above.
        raise ValueError(
            f"Multiple files found in {repo_id} matching {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    (matching_file,) = matching_files

    subfolder = str(Path(matching_file).parent)
    filename = Path(matching_file).name

    # download the file
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
        cache_dir=cache_dir,
    )

    if local_dir is None:
        # Resolve the cached path without touching the network again.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            cache_dir=cache_dir,
            local_files_only=True,
        )
    else:
        # Join with the full repo-relative path so a file matched inside a
        # subfolder is not looked up at the top of `local_dir`. For top-level
        # files this is identical to joining with the bare filename.
        # NOTE(review): relies on hf_hub_download mirroring the repo layout
        # under `local_dir` - confirm against the huggingface_hub docs.
        model_path = os.path.join(local_dir, matching_file)

    return cls(
        model_path=model_path,
        **kwargs,
    )
2024-02-08 09:07:03 +08:00
2024-01-17 09:16:13 -05:00
class LlamaState:
    """Plain value holder for a saved llama state snapshot."""

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        """Store the given snapshot fields on the instance unchanged."""
        # Python-side buffers.
        self.input_ids = input_ids
        self.scores = scores
        self.n_tokens = n_tokens
        # Serialized context state and its length in bytes.
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
# A logits processor maps (input_ids, scores) to new scores.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors that is itself callable.

    Calling the list feeds the scores through each processor in list order,
    passing every processor the output of the previous one.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        current = scores
        for apply_processor in self:
            current = apply_processor(input_ids, current)
        return current
# A stopping criterion maps (input_ids, logits) to a stop/continue flag.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria that is itself callable.

    Every criterion is evaluated (no short-circuiting); the result is true
    when at least one of them returned a truthy value.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        verdicts = [criterion(input_ids, logits) for criterion in self]
        return any(verdicts)
2024-05-14 22:50:53 +09:00
class MinTokensLogitsProcessor(LogitsProcessor):
    """Logits processor that masks the EOS score until enough tokens exist.

    The length of ``input_ids`` seen on the first call is remembered as the
    prompt length; while fewer than ``min_tokens`` tokens have been added
    beyond it, the score at index ``token_eos`` is set to ``-inf``.
    """

    def __init__(self, min_tokens: int, token_eos: int):
        self.min_tokens = min_tokens
        self.token_eos = token_eos
        # Filled in lazily on the first call with len(input_ids).
        self.prompt_tokens = None

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        if self.prompt_tokens is None:
            self.prompt_tokens = len(input_ids)
        generated = len(input_ids) - self.prompt_tokens
        if generated < self.min_tokens:
            # Mutates `scores` in place and returns the same array.
            scores[self.token_eos] = -np.inf
        return scores