2023-03-24 19:47:17 +00:00
import os
2023-04-04 17:09:24 +00:00
import sys
2023-03-23 09:33:06 +00:00
import uuid
import time
import multiprocessing
2023-05-25 18:04:54 +00:00
from typing import (
List ,
Optional ,
Union ,
Generator ,
Sequence ,
Iterator ,
Deque ,
Callable ,
)
2024-01-17 14:09:12 +00:00
from collections import deque
2023-03-23 09:33:06 +00:00
2023-07-15 19:11:01 +00:00
import ctypes
2023-03-23 09:33:06 +00:00
2023-04-01 17:01:27 +00:00
from . llama_types import *
2023-08-06 17:21:37 +00:00
from . llama_grammar import LlamaGrammar
2024-01-17 14:09:12 +00:00
from . llama_cache import (
BaseLlamaCache ,
LlamaCache , # type: ignore
LlamaDiskCache , # type: ignore
LlamaRAMCache , # type: ignore
)
2023-11-08 03:48:51 +00:00
import llama_cpp . llama_cpp as llama_cpp
2023-11-03 06:12:14 +00:00
import llama_cpp . llama_chat_format as llama_chat_format
2023-03-23 09:33:06 +00:00
2023-05-26 20:12:45 +00:00
import numpy as np
import numpy . typing as npt
2023-11-03 16:55:55 +00:00
from . _utils import suppress_stdout_stderr
2024-01-17 14:14:00 +00:00
from . _internals import (
_LlamaModel , # type: ignore
_LlamaContext , # type: ignore
_LlamaBatch , # type: ignore
_LlamaTokenDataArray , # type: ignore
)
2023-07-18 23:27:41 +00:00
2023-09-29 02:42:03 +00:00
2023-04-24 21:51:25 +00:00
class LlamaState:
    """Snapshot of a Llama instance's evaluation state.

    Bundles the evaluated token ids and their logits together with the raw
    llama.cpp context state bytes so the snapshot can be restored later.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Token-side bookkeeping.
        self.input_ids, self.scores = input_ids, scores
        self.n_tokens = n_tokens
        # Opaque context state blob and the size reported for it.
        self.llama_state, self.llama_state_size = llama_state, llama_state_size
2023-04-15 16:03:09 +00:00
2023-07-18 23:27:41 +00:00
# A logits processor takes the token ids evaluated so far plus the current
# logits and returns the (possibly modified) logits.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """Ordered chain of logits processors, applied left to right."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        """Feed ``scores`` through every processor in order.

        Each processor receives the output of the previous one; an empty
        chain returns ``scores`` unchanged (the same object).
        """
        current = scores
        for fn in self:
            current = fn(input_ids, current)
        return current
2023-07-18 23:27:41 +00:00
# A stopping criterion inspects the evaluated token ids and the latest logits
# and returns True when generation should stop.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """Collection of stopping criteria; stops if any criterion fires."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        """Return True if at least one criterion requests a stop.

        Every criterion is invoked, in order, even after one has already
        fired (no short-circuiting), so stateful criteria always observe
        each step.
        """
        should_stop = False
        for criterion in self:
            if criterion(input_ids, logits):
                should_stop = True
        return should_stop
2023-03-24 18:35:41 +00:00
2023-03-23 09:33:06 +00:00
class Llama :
2023-03-24 22:57:59 +00:00
""" High-level Python wrapper for a llama.cpp model. """
2023-09-14 03:00:43 +00:00
__backend_initialized = False
2023-03-23 09:33:06 +00:00
def __init__(
    self,
    model_path: str,
    *,
    # Model Params
    n_gpu_layers: int = 0,
    split_mode: int = llama_cpp.LLAMA_SPLIT_LAYER,
    main_gpu: int = 0,
    tensor_split: Optional[List[float]] = None,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
    # Context Params
    seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
    n_ctx: int = 512,
    n_batch: int = 512,
    n_threads: Optional[int] = None,
    n_threads_batch: Optional[int] = None,
    rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED,
    rope_freq_base: float = 0.0,
    rope_freq_scale: float = 0.0,
    yarn_ext_factor: float = -1.0,
    yarn_attn_factor: float = 1.0,
    yarn_beta_fast: float = 32.0,
    yarn_beta_slow: float = 1.0,
    yarn_orig_ctx: int = 0,
    mul_mat_q: bool = True,
    logits_all: bool = False,
    embedding: bool = False,
    offload_kqv: bool = False,
    # Sampling Params
    last_n_tokens_size: int = 64,
    # LoRA Params
    lora_base: Optional[str] = None,
    lora_scale: float = 1.0,
    lora_path: Optional[str] = None,
    # Backend Params
    numa: bool = False,
    # Chat Format Params
    chat_format: str = "llama-2",
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
    # Misc
    verbose: bool = True,
    # Extra Params
    **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Forced off when `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model metadata. Values may be bool, int or float.
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size (clamped to n_ctx).
        n_threads: Number of threads to use for generation (default: half the CPU count, at least 1).
        n_threads_batch: Number of threads to use for batch processing (default: half the CPU count, at least 1).
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        mul_mat_q: Forwarded to `llama_context_params.mul_mat_q`.
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale passed along with the LoRA adapter to `apply_lora_from_file`.
        lora_path: Path to a LoRA file to apply to the model.
        numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        verbose: Print verbose output to stderr.
        **kwargs: Accepted for forward compatibility and ignored.

    Raises:
        ValueError: If the model path does not exist, if `tensor_split` has
            more entries than LLAMA_MAX_DEVICES, or if a `kv_overrides`
            value is not a bool, int or float.
        RuntimeError: If the LoRA adapter fails to apply.

    Returns:
        A Llama instance.
    """
    self.verbose = verbose

    self.numa = numa
    # The llama.cpp backend is initialised once per process (class-level flag).
    if not Llama.__backend_initialized:
        with suppress_stdout_stderr(disable=self.verbose):
            llama_cpp.llama_backend_init(self.numa)
        Llama.__backend_initialized = True

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    # mmap is disabled whenever a LoRA adapter will be applied.
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        n_overrides = len(kv_overrides)
        # Instantiate (note the trailing `()`) a ctypes array with one extra
        # element; ctypes zero-initialises it, and that last element is the
        # terminating sentinel.
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * (n_overrides + 1)
        )()
        # Keep the encoded keys referenced for as long as the array lives.
        self._kv_overrides_array_keys = []
        for i, (k, v) in enumerate(kv_overrides.items()):
            # Plain bytes are accepted by ctypes both for fixed-size char
            # array fields and for char-pointer fields, whichever `key` is.
            key_bytes = k.encode("utf-8")
            self._kv_overrides_array_keys.append(key_bytes)
            self._kv_overrides_array[i].key = key_bytes
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(v, bool):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
                self._kv_overrides_array[i].value.bool_value = v
            elif isinstance(v, int):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
                self._kv_overrides_array[i].value.int_value = v
            elif isinstance(v, float):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
                self._kv_overrides_array[i].value.float_value = v
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")
        # null array sentinel
        self._kv_overrides_array_sentinel_key = b"\0"
        self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
        self.model_params.kv_overrides = self._kv_overrides_array

    # The batch can never be larger than the context; re-clamped below once
    # the real context size is known when n_ctx == 0.
    self.n_batch = min(n_ctx, n_batch)
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or max(
        multiprocessing.cpu_count() // 2, 1
    )

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.seed = seed
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
    )
    # 0 / negative "from model" sentinels are passed through unchanged.
    self.context_params.rope_freq_base = rope_freq_base
    self.context_params.rope_freq_scale = rope_freq_scale
    self.context_params.yarn_ext_factor = yarn_ext_factor
    self.context_params.yarn_attn_factor = yarn_attn_factor
    self.context_params.yarn_beta_fast = yarn_beta_fast
    self.context_params.yarn_beta_slow = yarn_beta_slow
    self.context_params.yarn_orig_ctx = yarn_orig_ctx
    self.context_params.mul_mat_q = mul_mat_q
    self.context_params.logits_all = logits_all
    self.context_params.embedding = embedding
    self.context_params.offload_kqv = offload_kqv

    # Sampling Params
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    self._model = _LlamaModel(
        path_model=self.model_path, params=self.model_params, verbose=self.verbose
    )

    # Set the default value for the context and correct the batch
    if n_ctx == 0:
        n_ctx = self._model.n_ctx_train()
        self.n_batch = min(n_ctx, n_batch)
        self.context_params.n_ctx = self._model.n_ctx_train()
        self.context_params.n_batch = self.n_batch

    self._ctx = _LlamaContext(
        model=self._model,
        params=self.context_params,
        verbose=self.verbose,
    )

    self._batch = _LlamaBatch(
        n_tokens=self.n_batch,
        embd=0,
        n_seq_max=self.context_params.n_ctx,
        verbose=self.verbose,
    )

    if self.lora_path:
        if self._model.apply_lora_from_file(
            self.lora_path,
            self.lora_scale,
            self.lora_base,
            self.n_threads,
        ):
            raise RuntimeError(
                f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab)

    # Preallocated history buffers: one token id and one logits row per
    # context position; only the first n_tokens entries are meaningful.
    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx, self._n_vocab), dtype=np.single
    )
2023-11-06 14:16:36 +00:00
@property
def ctx(self) -> llama_cpp.llama_context_p:
    """Underlying llama.cpp context handle held by the `_LlamaContext` wrapper.

    Asserts that the handle has not been released.
    """
    handle = self._ctx.ctx
    assert handle is not None
    return handle
@property
def model(self) -> llama_cpp.llama_model_p:
    """Underlying llama.cpp model handle held by the `_LlamaModel` wrapper.

    Asserts that the handle has not been released.
    """
    handle = self._model.model
    assert handle is not None
    return handle
2023-06-29 04:40:47 +00:00
@property
def _input_ids(self) -> npt.NDArray[np.intc]:
    """View of the token ids evaluated so far (first `n_tokens` entries)."""
    count = self.n_tokens
    return self.input_ids[0:count]
@property
def _scores(self) -> npt.NDArray[np.single]:
    """View of the logits rows for the evaluated tokens (first `n_tokens` rows)."""
    count = self.n_tokens
    return self.scores[0:count, :]
@property
def eval_tokens ( self ) - > Deque [ int ] :
return deque ( self . input_ids [ : self . n_tokens ] . tolist ( ) , maxlen = self . _n_ctx )
@property
def eval_logits ( self ) - > Deque [ List [ float ] ] :
return deque (
self . scores [ : self . n_tokens , : ] . tolist ( ) ,
2023-09-30 20:02:35 +00:00
maxlen = self . _n_ctx if self . context_params . logits_all else 1 ,
2023-06-29 04:40:47 +00:00
)
2023-05-26 20:12:45 +00:00
2023-11-06 14:16:36 +00:00
def tokenize(
    self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
    """Convert UTF-8 encoded text into model token ids.

    Args:
        text: UTF-8 encoded bytes to tokenize.
        add_bos: Whether to prepend the beginning-of-sequence token.
        special: Forwarded to the model tokenizer's `special` flag.

    Raises:
        RuntimeError: If the tokenization failed.

    Returns:
        The list of token ids.
    """
    model = self._model
    return model.tokenize(text, add_bos, special)
2023-03-28 05:45:37 +00:00
2023-05-19 15:59:33 +00:00
def detokenize(self, tokens: List[int]) -> bytes:
    """Convert token ids back into UTF-8 encoded bytes.

    Args:
        tokens: Token ids to convert.

    Returns:
        The decoded text as bytes.
    """
    model = self._model
    return model.detokenize(tokens)
2023-03-28 05:45:37 +00:00
2023-06-08 17:19:23 +00:00
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Attach a state cache to this model, or detach it by passing ``None``.

    Args:
        cache: Cache instance to use from now on, or ``None``.
    """
    self.cache = cache
2023-04-15 16:03:09 +00:00
2023-11-08 16:09:41 +00:00
def set_seed(self, seed: int):
    """Reseed the context's random number generator.

    Args:
        seed: New seed forwarded to `llama_set_rng_seed`.
    """
    ctx_handle = self._ctx.ctx
    assert ctx_handle is not None
    llama_cpp.llama_set_rng_seed(ctx_handle, seed)
2023-04-02 04:02:47 +00:00
def reset(self):
    """Forget all evaluated tokens.

    Only the token counter is cleared; the preallocated `input_ids` and
    `scores` buffers keep their stale contents, which are ignored because
    every reader slices them by `n_tokens`.
    """
    self.n_tokens = 0
2023-04-02 04:02:47 +00:00
2023-05-19 15:59:33 +00:00
def eval(self, tokens: Sequence[int]):
    """Run the model over `tokens`, appending them to the evaluated history.

    Tokens are decoded in chunks of at most `n_batch`. After each chunk the
    token ids are stored in `input_ids`, and the logits are copied into
    `scores`: every row when `logits_all` is set, otherwise only the row of
    the chunk's last token.

    Args:
        tokens: Token ids to evaluate, continuing from position `n_tokens`.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None
    # Discard cached KV entries beyond the current position so the new
    # tokens replace any stale continuation.
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
    total = len(tokens)
    step = self.n_batch
    for start in range(0, total, step):
        chunk = tokens[start : min(total, start + step)]
        pos = self.n_tokens
        count = len(chunk)
        keep_all = self.context_params.logits_all
        self._batch.set_batch(batch=chunk, n_past=pos, logits_all=keep_all)
        self._ctx.decode(self._batch)
        # Record the token ids of this chunk.
        self.input_ids[pos : pos + count] = chunk
        # Record logits; without logits_all only the final row is kept.
        width = self._n_vocab
        first_row = 0 if keep_all else count - 1
        target = self.scores[pos + first_row : pos + count, :].reshape(-1)
        target[:] = self._ctx.get_logits()[first_row * width : count * width]
        self.n_tokens += count
2023-05-01 18:47:55 +00:00
2023-11-06 14:16:36 +00:00
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
):
    """Sample the next token from the logits of the last evaluated token.

    Args:
        top_k: Top-k sampling parameter; values <= 0 mean the whole vocabulary.
        top_p: Top-p (nucleus) sampling parameter.
        min_p: Min-p sampling parameter.
        typical_p: Locally typical sampling parameter.
        temp: Temperature. 0 samples greedily; a negative value applies
            softmax and returns the first (most probable) candidate.
        repeat_penalty: Repetition penalty over the last-n-tokens window.
        frequency_penalty: Frequency penalty over the same window.
        presence_penalty: Presence penalty over the same window.
        tfs_z: Tail-free sampling parameter.
        mirostat_mode: 0 = top-k/tfs/typical/top-p/min-p chain,
            1 = mirostat, 2 = mirostat v2.
        mirostat_eta: Mirostat learning rate.
        mirostat_tau: Mirostat target value.
        penalize_nl: If False, the newline token's logit is restored after
            the repetition penalties are applied.
        logits_processor: Optional processors applied to the logits first.
        grammar: Optional grammar constraining the candidates; the sampled
            token is fed back into it.

    Returns:
        The sampled token id.
    """
    assert self._ctx is not None
    assert self.n_tokens > 0
    n_vocab = self._n_vocab
    n_ctx = self._n_ctx
    top_k = n_vocab if top_k <= 0 else top_k

    # Repetition-penalty window. A negative size means the whole context and
    # 0 disables the window. (Slicing with `[-0:]` would otherwise select
    # the entire history.) The window is left-padded with token id 0 while
    # fewer tokens than the window size have been evaluated.
    window = n_ctx if self.last_n_tokens_size < 0 else self.last_n_tokens_size
    recent_tokens = self._input_ids[-window:].tolist() if window > 0 else []
    last_n_tokens_data = [0] * max(0, window - self.n_tokens) + recent_tokens
    last_n_tokens_size = len(last_n_tokens_data)
    last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
        *last_n_tokens_data
    )

    # View into the last row of scores; processors write back into it.
    logits: npt.NDArray[np.single] = self._scores[-1, :]

    if logits_processor is not None:
        logits[:] = logits_processor(self._input_ids, logits)

    nl_logit = logits[self._token_nl]
    self._candidates.copy_logits(logits)
    self._ctx.sample_repetition_penalties(
        candidates=self._candidates,
        last_tokens_data=last_n_tokens_data_c,
        penalty_last_n=last_n_tokens_size,
        penalty_repeat=repeat_penalty,
        penalty_freq=frequency_penalty,
        penalty_present=presence_penalty,
    )
    if not penalize_nl:
        # Relies on candidates still being in token-id order at this point.
        self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
            nl_logit
        )

    if grammar is not None:
        self._ctx.sample_grammar(
            candidates=self._candidates,
            grammar=grammar,
        )

    # NOTE(review): mirostat `mu` is re-initialised to 2 * tau on every call,
    # so its adaptive state does not carry over between tokens. Confirm
    # whether _LlamaContext can return the updated mu so it can be persisted.
    if temp < 0.0:
        self._ctx.sample_softmax(candidates=self._candidates)
        token_id = self._candidates.candidates.data[0].id
    elif temp == 0.0:
        token_id = self._ctx.sample_token_greedy(candidates=self._candidates)
    elif mirostat_mode == 1:
        self._ctx.sample_temp(candidates=self._candidates, temp=temp)
        token_id = self._ctx.sample_token_mirostat(
            candidates=self._candidates,
            tau=mirostat_tau,
            eta=mirostat_eta,
            mu=2.0 * mirostat_tau,
            m=100,
        )
    elif mirostat_mode == 2:
        self._ctx.sample_temp(candidates=self._candidates, temp=temp)
        token_id = self._ctx.sample_token_mirostat_v2(
            candidates=self._candidates,
            tau=mirostat_tau,
            eta=mirostat_eta,
            mu=2.0 * mirostat_tau,
        )
    else:
        self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
        self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
        self._ctx.sample_typical(
            candidates=self._candidates, p=typical_p, min_keep=1
        )
        self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
        self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
        self._ctx.sample_temp(candidates=self._candidates, temp=temp)
        token_id = self._ctx.sample_token(candidates=self._candidates)
    if grammar is not None:
        self._ctx.grammar_accept_token(grammar=grammar, token=token_id)
    return token_id
2023-04-02 04:02:47 +00:00
2023-04-01 17:01:27 +00:00
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Create a generator of tokens from a prompt.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state. When the already-evaluated
            tokens share a prefix with `tokens`, that prefix is reused
            instead of resetting.
        (Remaining sampling arguments are forwarded unchanged to `sample`.)
        stopping_criteria: Checked after each sample; generation ends
            before yielding when it returns True.

    Yields:
        The generated tokens. A sequence of extra token ids may be sent back
        into the generator; those are evaluated together with the yielded
        token before the next sample.
    """
    if reset and self.n_tokens > 0:
        # Reuse the longest cached prefix. The last prompt token is never
        # matched, so at least one token is evaluated and fresh logits exist.
        matched = 0
        for cached, incoming in zip(self._input_ids, tokens[:-1]):
            if cached != incoming:
                break
            matched += 1
        if matched > 0:
            if self.verbose:
                print("Llama.generate: prefix-match hit", file=sys.stderr)
            reset = False
            tokens = tokens[matched:]
            self.n_tokens = matched

    if reset:
        self.reset()

    if grammar is not None:
        grammar.reset()

    while True:
        self.eval(tokens)
        token = self.sample(
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            typical_p=typical_p,
            temp=temp,
            repeat_penalty=repeat_penalty,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            logits_processor=logits_processor,
            grammar=grammar,
            penalize_nl=penalize_nl,
        )
        if stopping_criteria is not None and stopping_criteria(
            self._input_ids, self._scores[-1, :]
        ):
            return
        extra = yield token
        tokens = [token] if extra is None else [token, *extra]
2023-05-19 23:23:32 +00:00
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
    """Embed one string or a list of strings.

    Each input is tokenized, evaluated from a fresh state, and the context's
    embedding vector (first `n_embd` values) is collected.

    Args:
        input: A string or a list of strings to embed.
        model: Model name to report in the response; defaults to `model_path`.

    Raises:
        RuntimeError: If the model was not created with `embedding=True`.
        ValueError: If an input has more tokens than the context window.

    Returns:
        An OpenAI-style embedding list object with per-input embeddings and
        token usage.
    """
    assert self._ctx.ctx is not None
    assert self._model.model is not None
    model_name: str = model if model is not None else self.model_path

    if not self.context_params.embedding:
        raise RuntimeError(
            "Llama model must be created with embedding=True to call this method"
        )

    if self.verbose:
        llama_cpp.llama_reset_timings(self._ctx.ctx)

    inputs = [input] if isinstance(input, str) else input
    # Embedding width is a model constant; query it once.
    n_embd = llama_cpp.llama_n_embd(self._model.model)

    data: List[Embedding] = []
    total_tokens = 0
    for index, text in enumerate(inputs):
        tokens = self.tokenize(text.encode("utf-8"), special=True)
        n_tokens = len(tokens)
        # Fail with a clear message instead of overflowing the preallocated
        # per-context buffers inside eval().
        if n_tokens > self._n_ctx:
            raise ValueError(
                f"Requested tokens ({n_tokens}) exceed context window of {self._n_ctx}"
            )
        self.reset()
        self.eval(tokens)
        total_tokens += n_tokens
        embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[:n_embd]

        data.append(
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index,
            }
        )
    if self.verbose:
        llama_cpp.llama_print_timings(self._ctx.ctx)

    return {
        "object": "list",
        "data": data,
        "model": model_name,
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        },
    }
2023-03-28 06:42:22 +00:00
2023-04-03 22:46:19 +00:00
def embed(self, input: str) -> List[float]:
    """Return the embedding vector of a single string as Python floats.

    Args:
        input: The string to embed.

    Returns:
        The embedding as a list of floats.
    """
    response = self.create_embedding(input)
    vector = response["data"][0]["embedding"]
    return [float(value) for value in vector]
2023-04-01 17:01:27 +00:00
def _create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[
    Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
]:
    """Generator that implements ``create_completion``.

    Tokenizes ``prompt`` (a token list is used as-is), samples tokens with
    ``self.generate`` and yields OpenAI-style ``text_completion`` objects.

    * ``stream=False``: a single complete response, including ``usage``, is
      yielded after generation finishes.
    * ``stream=True``: partial chunks (``finish_reason`` None) are yielded as
      text becomes safe to emit. Text that could be the beginning of a stop
      sequence is held back, and the non-logprobs path avoids splitting UTF-8
      characters. A final chunk with empty text carries ``finish_reason``.

    Args:
        See ``create_completion``. ``logit_bias`` maps token ids (given as
        strings) to a value that is added to that token's logit.

    Raises:
        ValueError: If the prompt does not fit in the context window, or if
            ``logprobs`` is requested for a model created with
            ``logits_all=False``.
    """
    assert self._ctx is not None
    assert suffix is None or suffix.__class__ is str

    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    # If prompt is empty, initialize completion with BOS token to avoid
    # detokenization including a space at the beginning of the completion
    completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
    # String prompts are tokenized (an empty string becomes just BOS);
    # token-list prompts are used unchanged.
    prompt_tokens: List[int] = (
        (
            self.tokenize(prompt.encode("utf-8"), special=True)
            if prompt != ""
            else [self.token_bos()]
        )
        if isinstance(prompt, str)
        else prompt
    )
    text: bytes = b""
    returned_tokens: int = 0
    # Normalise ``stop`` to a list (a single string or None are accepted).
    stop = (
        stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
    )
    model_name: str = model if model is not None else self.model_path

    # NOTE: This likely doesn't work correctly for the first token in the
    # prompt because of how prompt_tokens is constructed above.
    if logit_bias is not None:
        logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}

        def logit_bias_processor(
            input_ids: npt.NDArray[np.intc],
            scores: npt.NDArray[np.single],
        ) -> npt.NDArray[np.single]:
            # Work on a copy so the caller's score array is not modified.
            new_scores = np.copy(scores)
            for input_id, score in logit_bias_map.items():
                new_scores[input_id] = score + scores[input_id]
            return new_scores

        if logits_processor is None:
            logits_processor = LogitsProcessorList([logit_bias_processor])
        else:
            # Build a new list: ``list.extend`` returns None (which previously
            # discarded every processor) and would also mutate the caller's list.
            logits_processor = LogitsProcessorList(
                [*logits_processor, logit_bias_processor]
            )

    if self.verbose:
        self._ctx.reset_timings()

    if len(prompt_tokens) >= self._n_ctx:
        raise ValueError(
            f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
        )

    if max_tokens is None or max_tokens <= 0:
        # Unlimited, depending on n_ctx.
        max_tokens = self._n_ctx - len(prompt_tokens)

    # Truncate max_tokens if requested tokens would exceed the context window
    max_tokens = min(max_tokens, self._n_ctx - len(prompt_tokens))

    stop_sequences = [s.encode("utf-8") for s in stop]

    if logprobs is not None and self.context_params.logits_all is False:
        raise ValueError(
            "logprobs is not supported for models created with logits_all=False"
        )

    if self.cache:
        try:
            cache_item = self.cache[prompt_tokens]
            cache_prefix_len = Llama.longest_token_prefix(
                cache_item.input_ids.tolist(), prompt_tokens
            )
            eval_prefix_len = Llama.longest_token_prefix(
                self._input_ids.tolist(), prompt_tokens
            )
            # Only restore the cached state if it shares a longer prefix with
            # the prompt than what is currently evaluated.
            if cache_prefix_len > eval_prefix_len:
                self.load_state(cache_item)
                if self.verbose:
                    print("Llama._create_completion: cache hit", file=sys.stderr)
        except KeyError:
            if self.verbose:
                print("Llama._create_completion: cache miss", file=sys.stderr)

    if seed is not None:
        self._ctx.set_rng_seed(seed)

    finish_reason = "length"
    multibyte_fix = 0
    for token in self.generate(
        prompt_tokens,
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        temp=temperature,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
    ):
        if token == self._token_eos:
            text = self.detokenize(completion_tokens)
            finish_reason = "stop"
            break

        completion_tokens.append(token)

        all_text = self.detokenize(completion_tokens)

        # Look for UTF-8 lead bytes (110xxxxx, 1110xxxx, 11110xxx) in the
        # last three bytes to detect a possibly incomplete multi-byte character.
        for k, char in enumerate(all_text[-3:]):
            k = 3 - k
            for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                # Bitwise AND check
                if num > k and pattern & char == pattern:
                    multibyte_fix = num - k

        # Stop incomplete bytes from passing
        if multibyte_fix > 0:
            multibyte_fix -= 1
            continue

        any_stop = [s for s in stop_sequences if s in all_text]
        if len(any_stop) > 0:
            # Cut at the earliest stop sequence, as the final stream flush does.
            text = all_text[: min(all_text.index(s) for s in any_stop)]
            finish_reason = "stop"
            break

        if stream:
            remaining_tokens = completion_tokens[returned_tokens:]
            remaining_text = self.detokenize(remaining_tokens)
            remaining_length = len(remaining_text)

            # We want to avoid yielding any characters from
            # the generated text if they are part of a stop
            # sequence.
            first_stop_position = 0
            for s in stop_sequences:
                for i in range(min(len(s), remaining_length), 0, -1):
                    if remaining_text.endswith(s[:i]):
                        if i > first_stop_position:
                            first_stop_position = i
                        break

            token_end_position = 0

            if logprobs is not None:
                # not sure how to handle this branch when dealing
                # with CJK output, so keep it unchanged
                for tok in remaining_tokens:
                    # NOTE(review): BOS is skipped without advancing
                    # returned_tokens, so with an empty prompt it stays in
                    # remaining_tokens and the tokens after it may be yielded
                    # again on the next step -- confirm.
                    if tok == self.token_bos():
                        continue
                    token_end_position += len(self.detokenize([tok]))
                    # Check if stop sequence is in the token
                    if token_end_position > (
                        remaining_length - first_stop_position
                    ):
                        break
                    token_str = self.detokenize([tok]).decode(
                        "utf-8", errors="ignore"
                    )
                    text_offset = len(prompt) + len(
                        self.detokenize(completion_tokens[:returned_tokens]).decode(
                            "utf-8", errors="ignore"
                        )
                    )
                    token_offset = len(prompt_tokens) + returned_tokens
                    logits = self._scores[token_offset - 1, :]
                    current_logprobs = Llama.logits_to_logprobs(logits).tolist()
                    sorted_logprobs = list(
                        sorted(
                            zip(current_logprobs, range(len(current_logprobs))),
                            reverse=True,
                        )
                    )
                    top_logprob = {
                        self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                        for logprob, i in sorted_logprobs[:logprobs]
                    }
                    top_logprob.update({token_str: current_logprobs[int(tok)]})
                    logprobs_or_none = {
                        "tokens": [
                            self.detokenize([tok]).decode("utf-8", errors="ignore")
                        ],
                        "text_offset": [text_offset],
                        "token_logprobs": [current_logprobs[int(tok)]],
                        "top_logprobs": [top_logprob],
                    }
                    returned_tokens += 1
                    yield {
                        "id": completion_id,
                        "object": "text_completion",
                        "created": created,
                        "model": model_name,
                        "choices": [
                            {
                                "text": self.detokenize([tok]).decode(
                                    "utf-8", errors="ignore"
                                ),
                                "index": 0,
                                "logprobs": logprobs_or_none,
                                "finish_reason": None,
                            }
                        ],
                    }
            else:
                while len(remaining_tokens) > 0:
                    # Find the shortest token prefix that decodes as UTF-8.
                    for i in range(1, len(remaining_tokens) + 1):
                        try:
                            bs = self.detokenize(remaining_tokens[:i])
                            ts = bs.decode("utf-8")
                            break
                        except UnicodeError:
                            pass
                    else:
                        # all remaining tokens cannot be decoded to a UTF-8 character
                        break
                    token_end_position += len(bs)
                    if token_end_position > (
                        remaining_length - first_stop_position
                    ):
                        break
                    remaining_tokens = remaining_tokens[i:]
                    returned_tokens += i

                    yield {
                        "id": completion_id,
                        "object": "text_completion",
                        "created": created,
                        "model": model_name,
                        "choices": [
                            {
                                "text": ts,
                                "index": 0,
                                "logprobs": None,
                                "finish_reason": None,
                            }
                        ],
                    }

        if len(completion_tokens) >= max_tokens:
            text = self.detokenize(completion_tokens)
            finish_reason = "length"
            break

    if stopping_criteria is not None and stopping_criteria(
        self._input_ids, self._scores[-1, :]
    ):
        text = self.detokenize(completion_tokens)
        finish_reason = "stop"

    if self.verbose:
        self._ctx.print_timings()

    if stream:
        # Flush whatever was generated but not yet yielded, cutting at the
        # earliest stop sequence.
        remaining_tokens = completion_tokens[returned_tokens:]
        all_text = self.detokenize(remaining_tokens)
        any_stop = [s for s in stop_sequences if s in all_text]
        if len(any_stop) > 0:
            end = min(all_text.index(s) for s in any_stop)
        else:
            end = len(all_text)

        token_end_position = 0
        for token in remaining_tokens:
            token_end_position += len(self.detokenize([token]))

            logprobs_or_none: Optional[CompletionLogprobs] = None
            if logprobs is not None:
                if token == self.token_bos():
                    continue
                token_str = self.detokenize([token]).decode(
                    "utf-8", errors="ignore"
                )
                # Offsets are measured in decoded characters, matching the
                # other logprobs paths.
                text_offset = len(prompt) + len(
                    self.detokenize(completion_tokens[:returned_tokens]).decode(
                        "utf-8", errors="ignore"
                    )
                )
                token_offset = len(prompt_tokens) + returned_tokens - 1
                logits = self._scores[token_offset, :]
                current_logprobs = Llama.logits_to_logprobs(logits).tolist()
                sorted_logprobs = list(
                    sorted(
                        zip(current_logprobs, range(len(current_logprobs))),
                        reverse=True,
                    )
                )
                top_logprob = {
                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                    for logprob, i in sorted_logprobs[:logprobs]
                }
                top_logprob.update({token_str: current_logprobs[int(token)]})
                logprobs_or_none = {
                    "tokens": [
                        self.detokenize([token]).decode("utf-8", errors="ignore")
                    ],
                    "text_offset": [text_offset],
                    "token_logprobs": [current_logprobs[int(token)]],
                    "top_logprobs": [top_logprob],
                }

            if token_end_position >= end:
                # This token reaches the cut point: emit only the part
                # before ``end`` and stop.
                last_text = self.detokenize([token])
                returned_tokens += 1
                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": model_name,
                    "choices": [
                        {
                            "text": last_text[
                                : len(last_text) - (token_end_position - end)
                            ].decode("utf-8", errors="ignore"),
                            "index": 0,
                            "logprobs": logprobs_or_none,
                            "finish_reason": None,
                        }
                    ],
                }
                break

            returned_tokens += 1
            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "choices": [
                    {
                        "text": self.detokenize([token]).decode(
                            "utf-8", errors="ignore"
                        ),
                        "index": 0,
                        "logprobs": logprobs_or_none,
                        "finish_reason": None,
                    }
                ],
            }

        # Terminal chunk carrying the finish reason.
        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "text": "",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": finish_reason,
                }
            ],
        }
        if self.cache:
            if self.verbose:
                print("Llama._create_completion: cache save", file=sys.stderr)
            self.cache[prompt_tokens + completion_tokens] = self.save_state()
            if self.verbose:
                print("Llama._create_completion: cache saved", file=sys.stderr)
        return

    if self.cache:
        if self.verbose:
            print("Llama._create_completion: cache save", file=sys.stderr)
        self.cache[prompt_tokens + completion_tokens] = self.save_state()

    text_str = text.decode("utf-8", errors="ignore")

    if echo:
        text_str = prompt + text_str

    if suffix is not None:
        text_str = text_str + suffix

    logprobs_or_none: Optional[CompletionLogprobs] = None
    if logprobs is not None:
        text_offset = 0 if echo else len(prompt)
        token_offset = 0 if echo else len(prompt_tokens[1:])
        text_offsets: List[int] = []
        token_logprobs: List[Optional[float]] = []
        tokens: List[str] = []
        top_logprobs: List[Optional[Dict[str, float]]] = []

        if echo:
            # Remove leading BOS token
            all_tokens = prompt_tokens[1:] + completion_tokens
        else:
            all_tokens = completion_tokens

        all_token_strs = [
            self.detokenize([token]).decode("utf-8", errors="ignore")
            for token in all_tokens
        ]
        all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
        # TODO: may be able to change this loop to use np.take_along_dim
        for idx, (token, token_str, logprobs_token) in enumerate(
            zip(all_tokens, all_token_strs, all_logprobs)
        ):
            if token == self.token_bos():
                continue
            text_offsets.append(
                text_offset
                + len(
                    self.detokenize(all_tokens[:idx]).decode(
                        "utf-8", errors="ignore"
                    )
                )
            )
            tokens.append(token_str)
            sorted_logprobs = list(
                sorted(
                    zip(logprobs_token, range(len(logprobs_token))), reverse=True
                )
            )
            token_logprobs.append(logprobs_token[int(token)])
            top_logprob: Optional[Dict[str, float]] = {
                self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                for logprob, i in sorted_logprobs[:logprobs]
            }
            top_logprob.update({token_str: logprobs_token[int(token)]})
            top_logprobs.append(top_logprob)
        # Weird idiosyncrasy of the OpenAI API where
        # token_logprobs and top_logprobs are null for
        # the first token. Guard on the collected lists: every token may
        # have been skipped as BOS, leaving them empty.
        if echo and len(token_logprobs) > 0:
            token_logprobs[0] = None
            top_logprobs[0] = None
        logprobs_or_none = {
            "tokens": tokens,
            "text_offset": text_offsets,
            "token_logprobs": token_logprobs,
            "top_logprobs": top_logprobs,
        }

    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "text": text_str,
                "index": 0,
                "logprobs": logprobs_or_none,
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": len(prompt_tokens),
            "completion_tokens": len(completion_tokens),
            "total_tokens": len(prompt_tokens) + len(completion_tokens),
        },
    }
2023-04-01 17:01:27 +00:00
def create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Complete ``prompt`` and return the generated text.

    Args:
        prompt: Text (or pre-tokenized ids) to continue.
        suffix: Text appended to the generated output; nothing is appended when None.
        max_tokens: Upper bound on generated tokens. None or a value <= 0 means
            "as many as fit in n_ctx".
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold ("The Curious Case of Neural Text
            Degeneration", https://arxiv.org/abs/1904.09751).
        min_p: Minimum-p sampling threshold
            (https://github.com/ggerganov/llama.cpp/pull/3841).
        typical_p: Locally typical sampling parameter
            (https://arxiv.org/abs/2202.00666).
        logprobs: How many log-probabilities to return per token; None disables them.
        echo: If True, the prompt is included in the returned text.
        stop: String or list of strings that end generation when produced.
        frequency_penalty: Penalty based on how often a token already appeared.
        presence_penalty: Penalty based on whether a token already appeared.
        repeat_penalty: Penalty applied to repeated tokens.
        top_k: Top-k sampling cutoff (https://arxiv.org/abs/1904.09751).
        stream: If True, return an iterator of chunks instead of one response.
        seed: RNG seed used for sampling.
        tfs_z: Tail-free sampling parameter
            (https://www.trentonbricken.com/Tail-Free-Sampling/).
        mirostat_mode: Mirostat sampling mode.
        mirostat_tau: Mirostat target surprise; higher values give less
            predictable text, lower values more predictable text.
        mirostat_eta: Mirostat learning rate for updating ``mu``; larger values
            make ``mu`` adapt faster.
        model: Model name reported in the response object.
        stopping_criteria: Extra stopping criteria.
        logits_processor: Extra logits processors.
        grammar: Grammar for constrained sampling.
        logit_bias: Per-token logit bias.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        A single response object, or an iterator of stream chunks when ``stream`` is True.
    """
    # The generator treats any non-positive limit as "fill the context".
    requested_max_tokens = max_tokens if max_tokens is not None else -1
    generator = self._create_completion(
        prompt=prompt,
        suffix=suffix,
        max_tokens=requested_max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        top_k=top_k,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    if not stream:
        # In non-streaming mode the generator yields one complete response.
        completion: Completion = next(generator)  # type: ignore
        return completion
    chunks: Iterator[CreateCompletionStreamResponse] = generator
    return chunks
2023-03-28 08:03:57 +00:00
def __call__(
    self,
    prompt: str,
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Calling the model object is shorthand for ``create_completion``.

    Every argument is forwarded unchanged to ``create_completion``:

    Args:
        prompt: Text to continue.
        suffix: Text appended to the generated output; nothing is appended when None.
        max_tokens: Upper bound on generated tokens. None or a value <= 0 means
            "as many as fit in n_ctx".
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold (https://arxiv.org/abs/1904.09751).
        min_p: Minimum-p sampling threshold
            (https://github.com/ggerganov/llama.cpp/pull/3841).
        typical_p: Locally typical sampling parameter
            (https://arxiv.org/abs/2202.00666).
        logprobs: How many log-probabilities to return per token; None disables them.
        echo: If True, the prompt is included in the returned text.
        stop: String or list of strings that end generation when produced.
        frequency_penalty: Penalty based on how often a token already appeared.
        presence_penalty: Penalty based on whether a token already appeared.
        repeat_penalty: Penalty applied to repeated tokens.
        top_k: Top-k sampling cutoff (https://arxiv.org/abs/1904.09751).
        stream: If True, return an iterator of chunks instead of one response.
        seed: RNG seed used for sampling.
        tfs_z: Tail-free sampling parameter
            (https://www.trentonbricken.com/Tail-Free-Sampling/).
        mirostat_mode: Mirostat sampling mode.
        mirostat_tau: Mirostat target surprise; higher values give less
            predictable text.
        mirostat_eta: Mirostat learning rate for updating ``mu``.
        model: Model name reported in the response object.
        stopping_criteria: Extra stopping criteria.
        logits_processor: Extra logits processors.
        grammar: Grammar for constrained sampling.
        logit_bias: Per-token logit bias.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        Whatever ``create_completion`` returns for these arguments.
    """
    options = dict(
        suffix=suffix,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        top_k=top_k,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    return self.create_completion(prompt=prompt, **options)
2023-04-04 00:12:44 +00:00
def create_chat_completion(
    self,
    messages: List[ChatCompletionRequestMessage],
    functions: Optional[List[ChatCompletionFunction]] = None,
    function_call: Optional[ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[ChatCompletionTool]] = None,
    tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = None,
    seed: Optional[int] = None,
    response_format: Optional[ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[
    CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
    """Generate a chat completion from a list of messages.

    The request is delegated to a chat handler: ``self.chat_handler`` when one
    was supplied, otherwise the handler returned by
    ``llama_chat_format.get_chat_completion_handler(self.chat_format)``. All
    arguments are forwarded to that handler unchanged, except that a ``None``
    ``stop`` is replaced by a fresh empty list.

    Args:
        messages: A list of messages to generate a response for.
        functions: A list of functions to use for the chat completion.
        function_call: A function call to use for the chat completion.
        tools: A list of tools to use for the chat completion.
        tool_choice: A tool choice to use for the chat completion.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        stream: Whether to stream the results.
        stop: A string or list of strings to stop generation when encountered.
            ``None`` (the default) means no extra stop strings.
        seed: The seed to use for sampling.
        response_format: The response format to use for the chat completion. Use {"type": "json_object"} to constrain output to only valid JSON.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
        presence_penalty: Presence penalty, forwarded to the chat handler.
        frequency_penalty: Frequency penalty, forwarded to the chat handler.
        repeat_penalty: The penalty to apply to repeated tokens.
        tfs_z: The tail-free sampling parameter.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The mirostat sampling tau parameter.
        mirostat_eta: The mirostat sampling eta parameter.
        model: The name to use for the model in the completion object.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use.
        logit_bias: A logit bias to use.

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    # Build a fresh list per call: a shared `[]` default would leak any
    # in-place modification made by a handler into every later call.
    if stop is None:
        stop = []

    handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
        self.chat_format
    )
    return handler(
        llama=self,
        messages=messages,
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        stream=stream,
        stop=stop,
        seed=seed,
        response_format=response_format,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
2023-04-05 10:52:17 +00:00
def __getstate__(self):
    """Return the constructor arguments needed to rebuild this instance.

    Used by pickle. Only configuration values are captured. Most are read back
    from the stored llama.cpp model/context parameter structs; the rest come
    from attributes kept on the instance itself.
    """
    mparams = self.model_params
    cparams = self.context_params
    state = {
        "model_path": self.model_path,
        # Model Params
        "n_gpu_layers": mparams.n_gpu_layers,
        "split_mode": mparams.split_mode,
        "main_gpu": mparams.main_gpu,
        "tensor_split": self.tensor_split,
        "vocab_only": mparams.vocab_only,
        "use_mmap": mparams.use_mmap,
        "use_mlock": mparams.use_mlock,
        "kv_overrides": self.kv_overrides,
        # Context Params
        "seed": cparams.seed,
        "n_ctx": cparams.n_ctx,
        "n_batch": self.n_batch,
        "n_threads": cparams.n_threads,
        "n_threads_batch": cparams.n_threads_batch,
        "rope_scaling_type": cparams.rope_scaling_type,
        "rope_freq_base": cparams.rope_freq_base,
        "rope_freq_scale": cparams.rope_freq_scale,
        "yarn_ext_factor": cparams.yarn_ext_factor,
        "yarn_attn_factor": cparams.yarn_attn_factor,
        "yarn_beta_fast": cparams.yarn_beta_fast,
        "yarn_beta_slow": cparams.yarn_beta_slow,
        "yarn_orig_ctx": cparams.yarn_orig_ctx,
        "mul_mat_q": cparams.mul_mat_q,
        "logits_all": cparams.logits_all,
        "embedding": cparams.embedding,
        # Sampling Params
        "last_n_tokens_size": self.last_n_tokens_size,
        # LoRA Params
        "lora_base": self.lora_base,
        "lora_scale": self.lora_scale,
        "lora_path": self.lora_path,
        # Backend Params
        "numa": self.numa,
        # Chat Format Params
        "chat_format": self.chat_format,
        "chat_handler": self.chat_handler,
        # Misc
        "verbose": self.verbose,
    }
    return state
def __setstate__(self, state):
    """Rebuild this instance from a dict produced by ``__getstate__``.

    Used by pickle. The object is re-initialised by calling ``__init__`` with
    the saved constructor arguments (presumably reloading the model from
    ``model_path`` — the work is done by ``__init__``).

    Args:
        state: Mapping of constructor argument names to saved values.

    Raises:
        KeyError: If ``state`` lacks any of the expected keys.
    """
    self.__init__(
        model_path=state["model_path"],
        # Model Params
        n_gpu_layers=state["n_gpu_layers"],
        split_mode=state["split_mode"],
        main_gpu=state["main_gpu"],
        tensor_split=state["tensor_split"],
        vocab_only=state["vocab_only"],
        use_mmap=state["use_mmap"],
        use_mlock=state["use_mlock"],
        kv_overrides=state["kv_overrides"],
        # Context Params
        seed=state["seed"],
        n_ctx=state["n_ctx"],
        n_batch=state["n_batch"],
        n_threads=state["n_threads"],
        n_threads_batch=state["n_threads_batch"],
        rope_freq_base=state["rope_freq_base"],
        rope_freq_scale=state["rope_freq_scale"],
        rope_scaling_type=state["rope_scaling_type"],
        yarn_ext_factor=state["yarn_ext_factor"],
        yarn_attn_factor=state["yarn_attn_factor"],
        yarn_beta_fast=state["yarn_beta_fast"],
        yarn_beta_slow=state["yarn_beta_slow"],
        yarn_orig_ctx=state["yarn_orig_ctx"],
        mul_mat_q=state["mul_mat_q"],
        logits_all=state["logits_all"],
        embedding=state["embedding"],
        # Sampling Params
        last_n_tokens_size=state["last_n_tokens_size"],
        # LoRA Params
        lora_base=state["lora_base"],
        # __getstate__ saves lora_scale; it must be forwarded, otherwise an
        # unpickled model silently falls back to the constructor default.
        lora_scale=state["lora_scale"],
        lora_path=state["lora_path"],
        # Backend Params
        numa=state["numa"],
        # Chat Format Params
        chat_format=state["chat_format"],
        chat_handler=state["chat_handler"],
        # Misc
        verbose=state["verbose"],
    )
2023-04-24 21:51:25 +00:00
def save_state(self) -> LlamaState:
    """Snapshot the native llama context together with the Python-side buffers.

    Queries the state size, copies the context state into a buffer of that
    size, then trims it to the number of bytes actually written.

    Returns:
        A LlamaState holding copies of ``scores`` and ``input_ids``, the current
        ``n_tokens``, the serialized context bytes, and their length.

    Raises:
        RuntimeError: If ``llama_copy_state_data`` reports writing more bytes
            than ``llama_get_state_size`` said were needed.
    """
    ctx = self._ctx.ctx
    assert ctx is not None

    def _trace(message: str) -> None:
        # Diagnostics only go to stderr when the instance is verbose.
        if self.verbose:
            print(message, file=sys.stderr)

    _trace("Llama.save_state: saving llama state")
    state_size = llama_cpp.llama_get_state_size(ctx)
    _trace(f"Llama.save_state: got state size: {state_size}")

    scratch = (llama_cpp.c_uint8 * int(state_size))()
    _trace("Llama.save_state: allocated state")

    n_bytes = llama_cpp.llama_copy_state_data(ctx, scratch)
    _trace(f"Llama.save_state: copied llama state: {n_bytes}")
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")

    # Keep only the bytes that were actually written.
    trimmed = (llama_cpp.c_uint8 * int(n_bytes))()
    llama_cpp.ctypes.memmove(trimmed, scratch, int(n_bytes))
    _trace(f"Llama.save_state: saving {n_bytes} bytes of llama state")

    return LlamaState(
        scores=self.scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=bytes(trimmed),
        llama_state_size=n_bytes,
    )
def load_state(self, state: LlamaState) -> None:
    """Restore a snapshot previously produced by ``save_state``.

    The native context is restored first. The Python-side buffers
    (``scores``, ``input_ids``, ``n_tokens``) are replaced only after
    ``llama_set_state_data`` reports that the full snapshot was consumed. A
    failed restore therefore does not leave them describing a state the
    context never loaded.

    Args:
        state: Snapshot returned by ``save_state``.

    Raises:
        ValueError: If ``state.llama_state`` is shorter than
            ``state.llama_state_size`` (raised by ``from_buffer_copy``).
        RuntimeError: If ``llama_set_state_data`` reports a byte count
            different from ``state.llama_state_size``.
    """
    assert self._ctx.ctx is not None
    state_size = state.llama_state_size
    LLamaStateArrayType = llama_cpp.c_uint8 * state_size
    llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
    if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
        raise RuntimeError("Failed to set llama state data")
    # Native restore succeeded; now bring the Python-side mirrors in line.
    self.scores = state.scores.copy()
    self.input_ids = state.input_ids.copy()
    self.n_tokens = state.n_tokens
2023-05-20 12:13:41 +00:00
def n_ctx(self) -> int:
    """Context window size, as reported by the wrapped context object."""
    context = self._ctx
    return context.n_ctx()
2023-05-20 12:13:41 +00:00
def n_embd(self) -> int:
    """Embedding dimension, as reported by the wrapped model object."""
    mdl = self._model
    return mdl.n_embd()
2023-05-20 12:13:41 +00:00
def n_vocab(self) -> int:
    """Vocabulary size, as reported by the wrapped model object."""
    mdl = self._model
    return mdl.n_vocab()
2023-05-20 12:13:41 +00:00
2023-05-25 18:11:33 +00:00
def tokenizer(self) -> "LlamaTokenizer":
    """Build a LlamaTokenizer bound to this model instance."""
    tok = LlamaTokenizer(self)
    return tok
2023-04-05 10:52:17 +00:00
2023-08-24 04:17:00 +00:00
def token_eos(self) -> int:
    """Id of the end-of-sequence token, taken from the wrapped model."""
    mdl = self._model
    return mdl.token_eos()
2023-04-01 21:29:30 +00:00
2023-08-24 04:17:00 +00:00
def token_bos(self) -> int:
    """Id of the beginning-of-sequence token, taken from the wrapped model."""
    mdl = self._model
    return mdl.token_bos()
2023-04-12 18:05:11 +00:00
2023-08-24 04:17:00 +00:00
def token_nl(self) -> int:
    """Id of the newline token, taken from the wrapped model."""
    mdl = self._model
    return mdl.token_nl()
2023-05-17 05:53:26 +00:00
2023-04-12 18:05:11 +00:00
@staticmethod
def logits_to_logprobs(
    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert raw logits to log-probabilities with a numerically stable log-softmax.

    Follows the algorithm of ``scipy.special.log_softmax``. The per-slice
    maximum is subtracted before exponentiating, and non-finite maxima are
    replaced by 0 before that shift.

    Args:
        logits: Array-like of logits.
        axis: Axis along which to normalise (default: the last axis).

    Returns:
        A float32 array with the same shape as ``logits``.
    """
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
    peak: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
    if peak.ndim == 0:
        if not np.isfinite(peak):
            peak = 0
    else:
        peak[~np.isfinite(peak)] = 0
    shifted = np.subtract(logits, peak, dtype=np.single)
    exps = np.exp(shifted)
    # A slice whose terms are all zero sums to 0; silence numpy's
    # divide-by-zero warning from log(0).
    with np.errstate(divide="ignore"):
        totals = np.sum(exps, axis=axis, keepdims=True)
        log_norm = np.log(totals)
    return shifted - log_norm
2023-05-07 23:31:26 +00:00
@staticmethod
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
    """Count how many leading tokens ``a`` and ``b`` have in common.

    Comparison stops at the first mismatch or when the shorter sequence ends.
    """
    shared = 0
    for left, right in zip(a, b):
        if left != right:
            break
        shared += 1
    return shared
2023-05-25 18:11:33 +00:00
class LlamaTokenizer:
    """Adapter that exposes a ``Llama`` model's tokenizer as str <-> token ids."""

    def __init__(self, llama: Llama):
        self.llama = llama

    def encode(self, text: str, add_bos: bool = True) -> List[int]:
        """Tokenize ``text`` (UTF-8 encoded, unencodable characters dropped).

        Always passes ``special=True`` to ``Llama.tokenize``.
        """
        raw = text.encode("utf-8", errors="ignore")
        return self.llama.tokenize(raw, add_bos=add_bos, special=True)

    def decode(self, tokens: List[int]) -> str:
        """Detokenize ``tokens`` and decode as UTF-8, silently dropping invalid bytes."""
        raw = self.llama.detokenize(tokens)
        return raw.decode("utf-8", errors="ignore")

    @classmethod
    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
        """Create a tokenizer from a model file, loading its vocabulary only."""
        vocab_model = Llama(model_path=path, vocab_only=True)
        return cls(vocab_model)