2024-01-17 14:16:13 +00:00
from __future__ import annotations
2023-03-24 19:47:17 +00:00
import os
2023-04-04 17:09:24 +00:00
import sys
2023-03-23 09:33:06 +00:00
import uuid
import time
2024-02-21 21:25:10 +00:00
import json
2024-02-23 16:24:53 +00:00
import ctypes
2024-02-21 21:25:10 +00:00
import fnmatch
2023-03-23 09:33:06 +00:00
import multiprocessing
2024-02-23 16:24:53 +00:00
2023-05-25 18:04:54 +00:00
from typing import (
List ,
Optional ,
Union ,
Generator ,
Sequence ,
Iterator ,
Deque ,
Callable ,
2024-04-17 14:06:50 +00:00
Dict ,
2023-05-25 18:04:54 +00:00
)
2024-01-17 14:09:12 +00:00
from collections import deque
2024-02-21 21:25:10 +00:00
from pathlib import Path
2023-03-23 09:33:06 +00:00
2024-02-08 01:07:03 +00:00
from llama_cpp . llama_types import List
2023-04-01 17:01:27 +00:00
from . llama_types import *
2023-08-06 17:21:37 +00:00
from . llama_grammar import LlamaGrammar
2024-01-17 14:09:12 +00:00
from . llama_cache import (
BaseLlamaCache ,
LlamaCache , # type: ignore
LlamaDiskCache , # type: ignore
LlamaRAMCache , # type: ignore
)
2024-02-21 21:25:10 +00:00
from . llama_tokenizer import BaseLlamaTokenizer , LlamaTokenizer
2023-11-08 03:48:51 +00:00
import llama_cpp . llama_cpp as llama_cpp
2023-11-03 06:12:14 +00:00
import llama_cpp . llama_chat_format as llama_chat_format
2023-03-23 09:33:06 +00:00
2024-01-31 19:08:14 +00:00
from llama_cpp . llama_speculative import LlamaDraftModel
2023-05-26 20:12:45 +00:00
import numpy as np
import numpy . typing as npt
2024-01-17 14:14:00 +00:00
from . _internals import (
_LlamaModel , # type: ignore
_LlamaContext , # type: ignore
_LlamaBatch , # type: ignore
_LlamaTokenDataArray , # type: ignore
2024-01-31 19:08:14 +00:00
_LlamaSamplingParams , # type: ignore
_LlamaSamplingContext , # type: ignore
2024-04-26 01:32:44 +00:00
_normalize_embedding , # type: ignore
2024-01-17 14:14:00 +00:00
)
2024-02-06 02:52:12 +00:00
from . _logger import set_verbose
2024-02-21 21:25:10 +00:00
from . _utils import suppress_stdout_stderr
2023-07-18 23:27:41 +00:00
2023-09-29 02:42:03 +00:00
2023-03-23 09:33:06 +00:00
class Llama :
2023-03-24 22:57:59 +00:00
""" High-level Python wrapper for a llama.cpp model. """
2023-09-14 03:00:43 +00:00
__backend_initialized = False
2023-03-23 09:33:06 +00:00
def __init__(
    self,
    model_path: str,
    *,
    # Model Params
    n_gpu_layers: int = 0,
    split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
    main_gpu: int = 0,
    tensor_split: Optional[List[float]] = None,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
    # Context Params
    seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
    n_ctx: int = 512,
    n_batch: int = 512,
    n_threads: Optional[int] = None,
    n_threads_batch: Optional[int] = None,
    rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
    pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
    rope_freq_base: float = 0.0,
    rope_freq_scale: float = 0.0,
    yarn_ext_factor: float = -1.0,
    yarn_attn_factor: float = 1.0,
    yarn_beta_fast: float = 32.0,
    yarn_beta_slow: float = 1.0,
    yarn_orig_ctx: int = 0,
    logits_all: bool = False,
    embedding: bool = False,
    offload_kqv: bool = True,
    flash_attn: bool = False,
    # Sampling Params
    last_n_tokens_size: int = 64,
    # LoRA Params
    lora_base: Optional[str] = None,
    lora_scale: float = 1.0,
    lora_path: Optional[str] = None,
    # Backend Params
    numa: Union[bool, int] = False,
    # Chat Format Params
    chat_format: Optional[str] = None,
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
    # Speculative Decoding
    draft_model: Optional[LlamaDraftModel] = None,
    # Tokenizer Override
    tokenizer: Optional[BaseLlamaTokenizer] = None,
    # KV cache quantization
    type_k: Optional[int] = None,
    type_v: Optional[int] = None,
    # Misc
    verbose: bool = True,
    # Extra Params
    **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    Besides loading the model this constructor initialises the llama.cpp
    backend once per process, creates the context and batch objects,
    optionally applies a LoRA adapter, allocates the token/score buffers and
    selects a chat format (explicit, guessed from GGUF metadata, taken from
    the embedded chat template, or the ``llama-2`` fallback).

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Forced to False whenever `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model. Values may be bool, int, float or str (str values are limited to 128 UTF-8 bytes).
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size (clamped to `n_ctx`).
        n_threads: Number of threads to use for generation. Defaults to half the CPU count (at least 1).
        n_threads_batch: Number of threads to use for batch processing. Defaults to the CPU count.
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        pooling_type: Pooling type, from `enum llama_pooling_type`.
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. Forced to True when a `draft_model` is supplied.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        flash_attn: Use flash attention.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale value forwarded, together with `lora_path` and `lora_base`, when the LoRA adapter is applied.
        lora_path: Path to a LoRA file to apply to the model.
        numa: numa policy. A bool selects DISTRIBUTE (True) or DISABLED (False); an int is used as the strategy value directly.
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        draft_model: Optional draft model to use for speculative decoding.
        tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
        type_k: KV cache data type for K (default: f16)
        type_v: KV cache data type for V (default: f16)
        verbose: Print verbose output to stderr.
        **kwargs: Extra keyword arguments; accepted but not read by this constructor.

    Raises:
        ValueError: If the model path does not exist, if `tensor_split` has more entries than LLAMA_MAX_DEVICES, or if a `kv_overrides` value has an unsupported type or is a string longer than 128 bytes.
        RuntimeError: If applying the LoRA adapter fails.

    Returns:
        A Llama instance.
    """
    self.verbose = verbose

    set_verbose(verbose)

    # The backend is initialised only once; the flag lives on the class so it
    # is shared by every Llama instance.
    if not Llama.__backend_initialized:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_backend_init()
        Llama.__backend_initialized = True

    # Translate the boolean shorthand into a concrete NUMA strategy value.
    if isinstance(numa, bool):
        self.numa = (
            llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
            if numa
            else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
        )
    else:
        self.numa = numa

    if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_numa_init(self.numa)

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    # mmap is switched off whenever a LoRA adapter path is supplied.
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    # kv_overrides is the original python dict
    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
        kvo_array_len = len(kv_overrides) + 1  # for sentinel element
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * kvo_array_len
        )()
        for i, (k, v) in enumerate(kv_overrides.items()):
            self._kv_overrides_array[i].key = k.encode("utf-8")
            # bool is tested before int because bool is a subclass of int.
            if isinstance(v, bool):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                self._kv_overrides_array[i].value.bool_value = v
            elif isinstance(v, int):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                self._kv_overrides_array[i].value.int_value = v
            elif isinstance(v, float):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                self._kv_overrides_array[i].value.float_value = v
            elif isinstance(v, str):  # type: ignore
                v_bytes = v.encode("utf-8")
                if len(v_bytes) > 128:  # TODO: Make this a constant
                    raise ValueError(f"Value for {k} is too long: {v}")
                # Pad with NUL bytes up to the fixed 128-byte width.
                v_bytes = v_bytes.ljust(128, b"\0")
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                # copy min(v_bytes, 128) to str_value
                ctypes.memmove(
                    self._kv_overrides_array[i].value.str_value,
                    v_bytes,
                    min(len(v_bytes), 128),
                )
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")

        self._kv_overrides_array[-1].key = (
            b"\0"  # ensure sentinel element is zeroed
        )
        self.model_params.kv_overrides = self._kv_overrides_array

    # NOTE: when n_ctx == 0 this yields 0 here; it is recomputed below once
    # the model's training context size is known.
    self.n_batch = min(n_ctx, n_batch)  # ???
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.seed = seed
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    self.context_params.pooling_type = pooling_type
    # The conditionals below pass the given value through and substitute the
    # integer 0 only when the argument already compares equal to zero.
    self.context_params.rope_freq_base = (
        rope_freq_base if rope_freq_base != 0.0 else 0
    )
    self.context_params.rope_freq_scale = (
        rope_freq_scale if rope_freq_scale != 0.0 else 0
    )
    self.context_params.yarn_ext_factor = (
        yarn_ext_factor if yarn_ext_factor != 0.0 else 0
    )
    self.context_params.yarn_attn_factor = (
        yarn_attn_factor if yarn_attn_factor != 0.0 else 0
    )
    self.context_params.yarn_beta_fast = (
        yarn_beta_fast if yarn_beta_fast != 0.0 else 0
    )
    self.context_params.yarn_beta_slow = (
        yarn_beta_slow if yarn_beta_slow != 0.0 else 0
    )
    self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
    self.context_params.logits_all = (
        logits_all if draft_model is None else True
    )  # Must be set to True for speculative decoding
    self.context_params.embeddings = embedding  # TODO: Rename to embeddings
    self.context_params.offload_kqv = offload_kqv
    self.context_params.flash_attn = flash_attn
    # KV cache quantization
    # Only override the library defaults when a type was explicitly given.
    if type_k is not None:
        self.context_params.type_k = type_k
    if type_v is not None:
        self.context_params.type_v = type_v
    # Sampling Params
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    self._model = _LlamaModel(
        path_model=self.model_path, params=self.model_params, verbose=self.verbose
    )

    # Override tokenizer
    self.tokenizer_ = tokenizer or LlamaTokenizer(self)

    # Set the default value for the context and correct the batch
    if n_ctx == 0:
        n_ctx = self._model.n_ctx_train()
        self.n_batch = min(n_ctx, n_batch)
        self.context_params.n_ctx = self._model.n_ctx_train()
        self.context_params.n_batch = self.n_batch

    self._ctx = _LlamaContext(
        model=self._model,
        params=self.context_params,
        verbose=self.verbose,
    )

    self._batch = _LlamaBatch(
        n_tokens=self.n_batch,
        embd=0,
        n_seq_max=self.context_params.n_ctx,
        verbose=self.verbose,
    )

    if self.lora_path:
        # A truthy return value from apply_lora_from_file is treated as failure.
        if self._model.apply_lora_from_file(
            self.lora_path,
            self.lora_scale,
            self.lora_base,
            self.n_threads,
        ):
            raise RuntimeError(
                f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler
    # Chat handlers built from templates found in the model metadata, by name.
    self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}

    self.draft_model = draft_model

    # Cache values that are queried repeatedly during evaluation/sampling.
    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab)

    # Evaluation state: number of evaluated tokens plus preallocated buffers
    # for the token ids (n_ctx,) and their logits (n_ctx, n_vocab).
    # np.ndarray(...) allocates without initialising the contents.
    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx, self._n_vocab), dtype=np.single
    )

    self._mirostat_mu = ctypes.c_float(
        2.0 * 5.0
    )  # TODO: Move this to sampling context

    # Metadata loading is best-effort: on failure fall back to an empty dict.
    try:
        self.metadata = self._model.metadata()
    except Exception as e:
        self.metadata = {}
        if self.verbose:
            print(f"Failed to load metadata: {e}", file=sys.stderr)

    if self.verbose:
        print(f"Model metadata: {self.metadata}", file=sys.stderr)

    eos_token_id = self.token_eos()
    bos_token_id = self.token_bos()

    # An id of -1 is treated as "no such token" and mapped to an empty string.
    eos_token = self._model.token_get_text(eos_token_id) if eos_token_id != -1 else ""
    bos_token = self._model.token_get_text(bos_token_id) if bos_token_id != -1 else ""

    # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
    # name[10:] drops the leading "tokenizer." so keys look like "chat_template.<name>".
    template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))

    if "tokenizer.chat_template" in self.metadata:
        template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]

    if self.verbose and template_choices:
        print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)

    # Register one Jinja2-based chat handler per template found in the metadata.
    for name, template in template_choices.items():
        self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
            template=template,
            eos_token=eos_token,
            bos_token=bos_token,
            stop_token_ids=[eos_token_id],
        ).to_chat_handler()

    # With no explicit format/handler and a default template available, first
    # try to recognise a known format from the metadata; otherwise use the
    # embedded template itself.
    if (
        self.chat_format is None
        and self.chat_handler is None
        and "chat_template.default" in template_choices
    ):
        chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
            self.metadata
        )

        if chat_format is not None:
            self.chat_format = chat_format
            if self.verbose:
                print(f"Guessed chat format: {chat_format}", file=sys.stderr)
        else:
            if self.verbose:
                print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_format = "chat_template.default"

    # Last resort when nothing was specified or detected.
    if self.chat_format is None and self.chat_handler is None:
        self.chat_format = "llama-2"
        if self.verbose:
            print(f"Using fallback chat format: {self.chat_format}", file=sys.stderr)
2024-01-29 19:22:23 +00:00
2023-11-06 14:16:36 +00:00
@property
def ctx(self) -> llama_cpp.llama_context_p:
    """Return the raw context handle stored on `self._ctx`.

    Asserts that the handle is not None before returning it.
    """
    handle = self._ctx.ctx
    assert handle is not None
    return handle
@property
def model(self) -> llama_cpp.llama_model_p:
    """Return the raw model handle stored on `self._model`.

    Asserts that the handle is not None before returning it.
    """
    handle = self._model.model
    assert handle is not None
    return handle
2023-06-29 04:40:47 +00:00
@property
def _input_ids ( self ) - > npt . NDArray [ np . intc ] :
return self . input_ids [ : self . n_tokens ]
@property
def _scores ( self ) - > npt . NDArray [ np . single ] :
return self . scores [ : self . n_tokens , : ]
@property
def eval_tokens(self) -> Deque[int]:
    """Return the evaluated token ids as a deque bounded to `_n_ctx` items.

    A new deque is built from the first `n_tokens` entries of `input_ids`
    on every access.
    """
    evaluated = self.input_ids[: self.n_tokens].tolist()
    return deque(evaluated, maxlen=self._n_ctx)
@property
def eval_logits(self) -> Deque[List[float]]:
    """Return the stored logits rows for the evaluated tokens as a deque.

    The deque holds up to `_n_ctx` rows when `context_params.logits_all` is
    set, and only a single row (the last one) otherwise.
    """
    rows = self.scores[: self.n_tokens, :].tolist()
    if self.context_params.logits_all:
        capacity = self._n_ctx
    else:
        capacity = 1
    return deque(rows, maxlen=capacity)
2023-05-26 20:12:45 +00:00
2023-11-06 14:16:36 +00:00
def tokenize(
    self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
    """Tokenize a byte string with the configured tokenizer.

    The call is delegated to `self.tokenizer_.tokenize`; the three arguments
    are passed through positionally and unchanged.

    Args:
        text: The utf-8 encoded string to tokenize.
        add_bos: Forwarded to the tokenizer as its second argument.
        special: Forwarded to the tokenizer as its third argument.

    Raises:
        RuntimeError: If the tokenization failed.

    Returns:
        A list of tokens.
    """
    token_ids = self.tokenizer_.tokenize(text, add_bos, special)
    return token_ids
2023-03-28 05:45:37 +00:00
2024-02-21 21:25:10 +00:00
def detokenize(
    self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
    """Convert a list of tokens back to bytes with the configured tokenizer.

    The call is delegated to `self.tokenizer_.detokenize`.

    Args:
        tokens: The list of tokens to detokenize.
        prev_tokens: The list of previous tokens. Offset mapping will be performed if provided

    Returns:
        The detokenized string.
    """
    decoded = self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
    return decoded
2023-03-28 05:45:37 +00:00
2023-06-08 17:19:23 +00:00
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Install `cache` as this instance's cache.

    Args:
        cache: The cache object to store on `self.cache`; may be None.
    """
    # Plain attribute replacement; the previous value is simply dropped.
    self.cache = cache
2023-04-15 16:03:09 +00:00
2023-11-08 16:09:41 +00:00
def set_seed(self, seed: int):
    """Set the random seed on the underlying context.

    Args:
        seed: The random seed, passed to `llama_cpp.llama_set_rng_seed`.
    """
    handle = self._ctx.ctx
    assert handle is not None
    llama_cpp.llama_set_rng_seed(handle, seed)
2023-04-02 04:02:47 +00:00
def reset(self):
    """Reset the model state by setting the evaluated-token count to zero.

    The `input_ids` and `scores` buffers are not touched here; only
    `n_tokens` is rewound.
    """
    self.n_tokens = 0
2023-04-02 04:02:47 +00:00
2023-05-19 15:59:33 +00:00
def eval(self, tokens: Sequence[int]):
    """Evaluate a list of tokens.

    Tokens are processed in chunks of at most `self.n_batch`. For each chunk
    the token ids are stored in `input_ids`, the resulting logits are stored
    in `scores`, and `n_tokens` is advanced by the chunk length.

    Args:
        tokens: The list of tokens to evaluate.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    total = len(tokens)
    for start in range(0, total, self.n_batch):
        chunk = tokens[start : min(total, start + self.n_batch)]
        first = self.n_tokens
        count = len(chunk)
        stop = first + count

        self._batch.set_batch(
            batch=chunk, n_past=first, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)

        # Record the token ids of this chunk.
        self.input_ids[first:stop] = chunk

        # Record logits: one row per token when logits_all is set, otherwise
        # a single row written at the position of the chunk's last token.
        if self.context_params.logits_all:
            n_rows = count
            target = self.scores[first:stop, :]
        else:
            n_rows = 1
            target = self.scores[stop - 1, :]
        target.reshape(-1)[::] = self._ctx.get_logits()[: n_rows * self._n_vocab]

        # Advance the evaluated-token counter.
        self.n_tokens += count
2023-05-01 18:47:55 +00:00
2023-11-06 14:16:36 +00:00
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    idx: Optional[int] = None,
):
    """Sample a token from the model.

    The logits row is taken from the stored scores (the last evaluated row,
    or row `idx` when given), optionally rewritten in place by
    `logits_processor`, and then handed to a freshly built sampling context.

    Args:
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: Forwarded to the sampling parameters as `min_p`.
        typical_p: Forwarded to the sampling parameters as `typical_p`.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        frequency_penalty: Forwarded to the sampling parameters as `penalty_freq`.
        presence_penalty: Forwarded to the sampling parameters as `penalty_present`.
        tfs_z: Forwarded to the sampling parameters as `tfs_z`.
        mirostat_mode: Forwarded to the sampling parameters as `mirostat`.
        mirostat_eta: Forwarded to the sampling parameters as `mirostat_eta`.
        mirostat_tau: Forwarded to the sampling parameters as `mirostat_tau`.
        penalize_nl: Forwarded to the sampling parameters as `penalize_nl`.
        logits_processor: Optional callable applied to (token ids, logits); its
            result overwrites the selected logits row.
        grammar: Optional grammar given to the sampling context.
        idx: Row of the stored scores to sample from; None selects the last row.

    Returns:
        The sampled token.
    """
    assert self._ctx is not None
    assert self.n_tokens > 0

    use_last_row = idx is None
    row = -1 if use_last_row else idx
    logits: npt.NDArray[np.single] = self._scores[row, :]

    if logits_processor is not None:
        # The processor sees the tokens up to and including the sampled row.
        history = self._input_ids if use_last_row else self._input_ids[: idx + 1]
        logits[:] = logits_processor(history, logits)

    sampler = _LlamaSamplingContext(
        params=_LlamaSamplingParams(
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            tfs_z=tfs_z,
            typical_p=typical_p,
            temp=temp,
            penalty_last_n=self.last_n_tokens_size,
            penalty_repeat=repeat_penalty,
            penalty_freq=frequency_penalty,
            penalty_present=presence_penalty,
            mirostat=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            penalize_nl=penalize_nl,
        ),
        grammar=grammar,
    )
    sampler.prev = list(self.eval_tokens)

    token_id = sampler.sample(ctx_main=self._ctx, logits_array=logits)
    sampler.accept(
        ctx_main=self._ctx,
        id=token_id,
        apply_grammar=grammar is not None,
    )
    return token_id
2023-04-02 04:02:47 +00:00
2023-04-01 17:01:27 +00:00
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Create a generator of tokens from a prompt.

    The generator alternates between evaluating pending tokens and sampling.
    A value sent into the generator (via ``send``) is appended after the
    token that was just yielded and evaluated together with it on the next
    round. The generator runs until `stopping_criteria` returns a truthy
    value or the caller stops iterating.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: Forwarded to `sample`.
        typical_p: Forwarded to `sample`.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state. Ignored (treated as False)
            when a non-empty prefix of the prompt matches the tokens that
            are already evaluated.
        frequency_penalty: Forwarded to `sample`.
        presence_penalty: Forwarded to `sample`.
        tfs_z: Forwarded to `sample`.
        mirostat_mode: Forwarded to `sample`.
        mirostat_tau: Forwarded to `sample`; also used to re-initialise
            `self._mirostat_mu` to ``2.0 * mirostat_tau``.
        mirostat_eta: Forwarded to `sample`.
        penalize_nl: Forwarded to `sample`.
        logits_processor: Forwarded to `sample`.
        stopping_criteria: Optional callable invoked after each sample with
            the evaluated token ids and the last row of scores; a truthy
            result ends the generator before the token is yielded.
        grammar: Optional grammar; it is reset before generation starts and
            forwarded to `sample`.

    Yields:
        The generated tokens.
    """
    # Reset mirostat sampling
    self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)

    # Check for kv cache prefix match
    if reset and self.n_tokens > 0:
        longest_prefix = 0
        # The last prompt token is excluded from the comparison, so at least
        # one token is always left over to evaluate after trimming.
        for a, b in zip(self._input_ids, tokens[:-1]):
            if a == b:
                longest_prefix += 1
            else:
                break
        if longest_prefix > 0:
            if self.verbose:
                print("Llama.generate: prefix-match hit", file=sys.stderr)
            # Keep the matching prefix: skip the reset, drop the matched
            # tokens from the pending list and rewind n_tokens to the match.
            reset = False
            tokens = tokens[longest_prefix:]
            self.n_tokens = longest_prefix

    # Reset the model state
    if reset:
        self.reset()

    # Reset the grammar
    if grammar is not None:
        grammar.reset()

    # Index of the scores row to sample from next: the row of the last
    # pending token once it has been evaluated.
    sample_idx = self.n_tokens + len(tokens) - 1
    # Work on a mutable copy; the list is cleared and refilled each round.
    tokens = list(tokens)

    # Eval and sample
    while True:
        self.eval(tokens)
        # Sample once for every evaluated position at or after sample_idx.
        while sample_idx < self.n_tokens:
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
                grammar=grammar,
                penalize_nl=penalize_nl,
                idx=sample_idx,
            )

            sample_idx += 1
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids, self._scores[-1, :]
            ):
                return
            # Yield the token; anything sent back is queued after it.
            tokens_or_none = yield token
            tokens.clear()
            tokens.append(token)
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

            # If an already-evaluated token at this position differs from
            # the token just sampled, rewind n_tokens to this position, call
            # kv_cache_seq_rm from there (presumably discarding the stale
            # cache entries - confirm against _LlamaContext) and leave the
            # sampling loop so the pending tokens get re-evaluated.
            if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                self.n_tokens = sample_idx
                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
                break

        if self.draft_model is not None:
            # Stage the pending tokens in input_ids (n_tokens itself is not
            # advanced here) so the draft model sees the full sequence.
            self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
            draft_tokens = self.draft_model(
                self.input_ids[: self.n_tokens + len(tokens)]
            )
            # Append the drafted tokens, truncated so that evaluated plus
            # pending tokens do not exceed _n_ctx.
            tokens.extend(
                draft_tokens.astype(int)[
                    : self._n_ctx - self.n_tokens - len(tokens)
                ]
            )
2023-04-01 17:01:27 +00:00
2023-05-19 23:23:32 +00:00
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
    """Build an OpenAI-style embedding response for one or more texts.

    Args:
        input: A single string, or a list of strings, to embed.
        model: Name reported in the response's "model" field. When None,
            ``self.model_path`` is reported instead.

    Returns:
        A dict with "object", "data" (one entry per input, in input order),
        "model" and "usage" keys. Both usage counters carry the token total
        returned by ``self.embed``.
    """
    assert self._model.model is not None

    if model is not None:
        model_name: str = model
    else:
        model_name = self.model_path

    # ``embed`` is always handed a list, so it always answers with a list.
    texts = input if isinstance(input, list) else [input]

    embeds: Union[List[List[float]], List[List[List[float]]]]
    total_tokens: int
    embeds, total_tokens = self.embed(texts, return_count=True)  # type: ignore

    # Wrap every raw embedding in the per-item response structure.
    data: List[Embedding] = []
    for position, vector in enumerate(embeds):
        data.append(
            {
                "object": "embedding",
                "embedding": vector,
                "index": position,
            }
        )

    usage = {
        "prompt_tokens": total_tokens,
        "total_tokens": total_tokens,
    }
    return {
        "object": "list",
        "data": data,
        "model": model_name,
        "usage": usage,
    }
def embed(
    self,
    input: Union[str, List[str]],
    normalize: bool = False,
    truncate: bool = True,
    return_count: bool = False,
):
    """Compute embeddings for a string or a list of strings.

    Inputs are tokenized and packed into batches of at most ``self.n_batch``
    tokens; each batch is decoded and its embeddings are collected.

    Args:
        input: A single text, or a list of texts, to embed.
        normalize: When True, every embedding vector is passed through
            ``_normalize_embedding`` before being returned.
        truncate: When True, each tokenized input is cut to its first
            ``n_batch`` tokens.
        return_count: When True, also return the total number of tokens
            that were processed.

    Returns:
        For a ``str`` input the embedding of that text, for a list input a
        list with one embedding per text. With pooling type
        ``LLAMA_POOLING_TYPE_NONE`` an embedding is a list of per-token
        vectors, otherwise it is a single vector of ``n_embd`` floats.
        If ``return_count`` is True the result is ``(output, total_tokens)``.

    Raises:
        RuntimeError: If the context was created without embeddings enabled.
        ValueError: If an input has more than ``n_batch`` tokens (only
            possible when ``truncate`` is False).
    """
    assert self._ctx.ctx is not None
    n_embd = self.n_embd()
    n_batch = self.n_batch

    # get pooling information
    pooling_type = self.pooling_type()
    logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE

    if not self.context_params.embeddings:
        raise RuntimeError(
            "Llama model must be created with embedding=True to call this method"
        )

    if self.verbose:
        llama_cpp.llama_reset_timings(self._ctx.ctx)

    if isinstance(input, str):
        inputs = [input]
    else:
        inputs = input

    # reset batch
    self._batch.reset()

    # decode and fetch embeddings
    data: Union[List[List[float]], List[List[List[float]]]] = []

    def decode_batch(seq_sizes: List[int]):
        """Decode the pending batch and append one embedding per sequence."""
        assert self._ctx.ctx is not None
        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
        self._ctx.decode(self._batch)
        self._batch.reset()

        # store embeddings
        if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
            # The token-level buffer is flat: one row of n_embd floats per
            # token of the batch (see llama.h, llama_get_embeddings).
            ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
            # `pos` is the row (token) index at which the current sequence
            # starts, so it has to be scaled by n_embd to get a float offset.
            pos: int = 0
            for size in seq_sizes:
                token_embeddings: List[List[float]] = [
                    ptr[(pos + j) * n_embd : (pos + j + 1) * n_embd]
                    for j in range(size)
                ]
                if normalize:
                    token_embeddings = [
                        _normalize_embedding(e) for e in token_embeddings
                    ]
                data.append(token_embeddings)
                pos += size
        else:
            for i in range(len(seq_sizes)):
                ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
                embedding: List[float] = ptr[:n_embd]
                if normalize:
                    embedding = _normalize_embedding(embedding)
                data.append(embedding)

    # init state
    total_tokens = 0
    s_batch = []  # token count of every sequence in the pending batch
    t_batch = 0  # total tokens in the pending batch
    p_batch = 0  # sequence id to use for the next sequence

    # accumulate batches and encode
    for text in inputs:
        tokens = self.tokenize(text.encode("utf-8"))
        if truncate:
            tokens = tokens[:n_batch]

        n_tokens = len(tokens)
        total_tokens += n_tokens

        # check for overrun
        if n_tokens > n_batch:
            raise ValueError(
                f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
            )

        # time to eval batch
        if t_batch + n_tokens > n_batch:
            decode_batch(s_batch)
            s_batch = []
            t_batch = 0
            p_batch = 0

        # add to batch
        self._batch.add_sequence(tokens, p_batch, logits_all)

        # update batch stats
        s_batch.append(n_tokens)
        t_batch += n_tokens
        p_batch += 1

    # handle last batch (nothing is pending when `inputs` was empty)
    if s_batch:
        decode_batch(s_batch)

    if self.verbose:
        llama_cpp.llama_print_timings(self._ctx.ctx)

    output = data[0] if isinstance(input, str) else data

    llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
    self.reset()

    if return_count:
        return output, total_tokens
    else:
        return output
2023-04-03 22:46:19 +00:00
2023-04-01 17:01:27 +00:00
def _create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[
    Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
]:
    """Generator that drives ``self.generate`` and yields completion objects.

    The parameters have the same meaning as in ``create_completion``.

    Yields:
        When ``stream`` is False: exactly one completion dict containing the
        full text, optional logprobs, the finish reason and token usage.
        When ``stream`` is True: a sequence of chunk dicts whose
        "finish_reason" is None, followed by a final chunk with empty text
        that carries the finish reason.

    Raises:
        ValueError: If the prompt has at least ``self._n_ctx`` tokens, or if
            ``logprobs`` is requested while ``context_params.logits_all`` is
            False.
    """
    assert self._ctx is not None
    assert suffix is None or suffix.__class__ is str

    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    prefix_token_id: int = self._model.token_prefix()
    middle_token_id: int = self._model.token_middle()
    suffix_token_id: int = self._model.token_suffix()
    # If prompt is empty, initialize completion with BOS token to avoid
    # detokenization including a space at the beginning of the completion
    completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
    # Add blank space to start of prompt to match OG llama tokenizer
    # Layout: [prefix token] + prompt tokens + [suffix token + suffix tokens]
    # + [middle token]; the bracketed parts are only present when a suffix is
    # given and the model defines the corresponding special token (id >= 0).
    prompt_tokens: List[int] = (
        (
            [prefix_token_id]
            if prefix_token_id >= 0 and suffix is not None
            else []
        )
        + (
            (
                self.tokenize(
                    prompt.encode("utf-8"),
                    add_bos=(prefix_token_id < 0 or suffix is None),
                    special=(prefix_token_id < 0 or suffix is None),
                )
                if prompt != ""
                else (
                    []
                    if prefix_token_id >= 0 and suffix is not None
                    else [self.token_bos()]
                )
            )
            if isinstance(prompt, str)
            else prompt
        )
        + (
            (
                [suffix_token_id]
                + (
                    self.tokenize(
                        suffix.encode("utf-8"), add_bos=False, special=False
                    )
                    if suffix
                    else []
                )
            )
            if suffix_token_id >= 0 and suffix is not None
            else []
        )
        + (
            [middle_token_id]
            if middle_token_id >= 0 and suffix is not None
            else []
        )
    )
    text: bytes = b""
    returned_tokens: int = 0
    # Normalize `stop` to a list of strings (None becomes an empty list).
    stop = (
        stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
    )
    model_name: str = model if model is not None else self.model_path

    # NOTE: This likely doesn't work correctly for the first token in the prompt
    # because of the extra space added to the start of the prompt_tokens
    if logit_bias is not None:
        logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}

        def logit_bias_processor(
            input_ids: npt.NDArray[np.intc],
            scores: npt.NDArray[np.single],
        ) -> npt.NDArray[np.single]:
            new_scores = np.copy(
                scores
            )  # Does it make sense to copy the whole array or can we just overwrite the original one?
            for input_id, score in logit_bias_map.items():
                new_scores[input_id] = score + scores[input_id]
            return new_scores

        if logits_processor is None:
            logits_processor = LogitsProcessorList([logit_bias_processor])
        else:
            # list.extend() works in place and returns None, so its result
            # must not be assigned. Build a new list instead, which also
            # leaves the caller's processor list untouched.
            logits_processor = LogitsProcessorList(
                [*logits_processor, logit_bias_processor]
            )

    if self.verbose:
        self._ctx.reset_timings()

    if len(prompt_tokens) >= self._n_ctx:
        raise ValueError(
            f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
        )

    if max_tokens is None or max_tokens <= 0:
        # Unlimited, depending on n_ctx.
        max_tokens = self._n_ctx - len(prompt_tokens)

    # Truncate max_tokens if requested tokens would exceed the context window
    max_tokens = (
        max_tokens
        if max_tokens + len(prompt_tokens) < self._n_ctx
        else (self._n_ctx - len(prompt_tokens))
    )

    if stop != []:
        stop_sequences = [s.encode("utf-8") for s in stop]
    else:
        stop_sequences = []

    if logprobs is not None and self.context_params.logits_all is False:
        raise ValueError(
            "logprobs is not supported for models created with logits_all=False"
        )

    if self.cache:
        try:
            cache_item = self.cache[prompt_tokens]
            cache_prefix_len = Llama.longest_token_prefix(
                cache_item.input_ids.tolist(), prompt_tokens
            )
            eval_prefix_len = Llama.longest_token_prefix(
                self._input_ids.tolist(), prompt_tokens
            )
            # Only restore the cached state when it shares a longer prefix
            # with the prompt than the tokens already evaluated.
            if cache_prefix_len > eval_prefix_len:
                self.load_state(cache_item)
                if self.verbose:
                    print("Llama._create_completion: cache hit", file=sys.stderr)
        except KeyError:
            if self.verbose:
                print("Llama._create_completion: cache miss", file=sys.stderr)

    if seed is not None:
        self._ctx.set_rng_seed(seed)

    finish_reason = "length"
    multibyte_fix = 0
    for token in self.generate(
        prompt_tokens,
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        temp=temperature,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
    ):
        assert self._model.model is not None
        if llama_cpp.llama_token_is_eog(self._model.model, token):
            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
            finish_reason = "stop"
            break

        completion_tokens.append(token)

        all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)

        # Contains multi-byte UTF8
        for k, char in enumerate(all_text[-3:]):
            k = 3 - k
            for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                # Bitwise AND check
                if num > k and pattern & char == pattern:
                    multibyte_fix = num - k

        # Stop incomplete bytes from passing
        if multibyte_fix > 0:
            multibyte_fix -= 1
            continue

        any_stop = [s for s in stop_sequences if s in all_text]
        if len(any_stop) > 0:
            first_stop = any_stop[0]
            text = all_text[: all_text.index(first_stop)]
            finish_reason = "stop"
            break

        if stream:
            remaining_tokens = completion_tokens[returned_tokens:]
            remaining_text = self.detokenize(
                remaining_tokens,
                prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
            )
            remaining_length = len(remaining_text)

            # We want to avoid yielding any characters from
            # the generated text if they are part of a stop
            # sequence.
            first_stop_position = 0
            for s in stop_sequences:
                for i in range(min(len(s), remaining_length), 0, -1):
                    if remaining_text.endswith(s[:i]):
                        if i > first_stop_position:
                            first_stop_position = i
                        break

            token_end_position = 0

            if logprobs is not None:
                # not sure how to handle this branch when dealing
                # with CJK output, so keep it unchanged
                for token in remaining_tokens:
                    if token == self.token_bos():
                        continue
                    token_end_position += len(
                        self.detokenize(
                            [token],
                            prev_tokens=prompt_tokens
                            + completion_tokens[:returned_tokens],
                        )
                    )
                    # Check if stop sequence is in the token
                    if token_end_position > (
                        remaining_length - first_stop_position
                    ):
                        break
                    token_str = self.detokenize(
                        [token],
                        prev_tokens=prompt_tokens
                        + completion_tokens[:returned_tokens],
                    ).decode("utf-8", errors="ignore")
                    text_offset = len(prompt) + len(
                        self.detokenize(
                            completion_tokens[:returned_tokens],
                            prev_tokens=prompt_tokens
                            + completion_tokens[:returned_tokens],
                        ).decode("utf-8", errors="ignore")
                    )
                    token_offset = len(prompt_tokens) + returned_tokens
                    logits = self._scores[token_offset - 1, :]
                    current_logprobs = Llama.logits_to_logprobs(logits).tolist()
                    sorted_logprobs = list(
                        sorted(
                            zip(current_logprobs, range(len(current_logprobs))),
                            reverse=True,
                        )
                    )
                    top_logprob = {
                        self.detokenize([i]).decode(
                            "utf-8", errors="ignore"
                        ): logprob
                        for logprob, i in sorted_logprobs[:logprobs]
                    }
                    top_logprob.update({token_str: current_logprobs[int(token)]})
                    logprobs_or_none = {
                        "tokens": [
                            self.detokenize(
                                [token],
                                prev_tokens=prompt_tokens
                                + completion_tokens[:returned_tokens],
                            ).decode("utf-8", errors="ignore")
                        ],
                        "text_offset": [text_offset],
                        "token_logprobs": [current_logprobs[int(token)]],
                        "top_logprobs": [top_logprob],
                    }
                    returned_tokens += 1
                    yield {
                        "id": completion_id,
                        "object": "text_completion",
                        "created": created,
                        "model": model_name,
                        "choices": [
                            {
                                "text": self.detokenize(
                                    [token],
                                    prev_tokens=prompt_tokens
                                    + completion_tokens[:returned_tokens],
                                ).decode("utf-8", errors="ignore"),
                                "index": 0,
                                "logprobs": logprobs_or_none,
                                "finish_reason": None,
                            }
                        ],
                    }
            else:
                while len(remaining_tokens) > 0:
                    decode_success = False
                    # Find the shortest run of leading tokens that decodes
                    # to valid UTF-8 so that no partial character is emitted.
                    for i in range(1, len(remaining_tokens) + 1):
                        try:
                            bs = self.detokenize(
                                remaining_tokens[:i],
                                prev_tokens=prompt_tokens
                                + completion_tokens[:returned_tokens],
                            )
                            ts = bs.decode("utf-8")
                            decode_success = True
                            break
                        except UnicodeError:
                            pass
                    else:
                        break
                    if not decode_success:
                        # all remaining tokens cannot be decoded to a UTF-8 character
                        break
                    token_end_position += len(bs)
                    if token_end_position > (
                        remaining_length - first_stop_position
                    ):
                        break
                    remaining_tokens = remaining_tokens[i:]
                    returned_tokens += i

                    yield {
                        "id": completion_id,
                        "object": "text_completion",
                        "created": created,
                        "model": model_name,
                        "choices": [
                            {
                                "text": ts,
                                "index": 0,
                                "logprobs": None,
                                "finish_reason": None,
                            }
                        ],
                    }

        if len(completion_tokens) >= max_tokens:
            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
            finish_reason = "length"
            break

    if stopping_criteria is not None and stopping_criteria(
        self._input_ids, self._scores[-1, :]
    ):
        text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
        finish_reason = "stop"

    if self.verbose:
        self._ctx.print_timings()

    if stream:
        # Flush the tokens that were held back during generation, cutting the
        # text at the earliest stop sequence if one is present.
        remaining_tokens = completion_tokens[returned_tokens:]
        all_text = self.detokenize(
            remaining_tokens,
            prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
        )
        any_stop = [s for s in stop_sequences if s in all_text]
        if len(any_stop) > 0:
            end = min(all_text.index(stop) for stop in any_stop)
        else:
            end = len(all_text)

        token_end_position = 0
        for token in remaining_tokens:
            token_end_position += len(
                self.detokenize(
                    [token],
                    prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
                )
            )

            logprobs_or_none: Optional[CompletionLogprobs] = None
            if logprobs is not None:
                if token == self.token_bos():
                    continue
                token_str = self.detokenize([token]).decode(
                    "utf-8", errors="ignore"
                )
                text_offset = len(prompt) + len(
                    self.detokenize(
                        completion_tokens[:returned_tokens],
                        prev_tokens=prompt_tokens
                        + completion_tokens[:returned_tokens],
                    )
                )
                token_offset = len(prompt_tokens) + returned_tokens - 1
                logits = self._scores[token_offset, :]
                current_logprobs = Llama.logits_to_logprobs(logits).tolist()
                sorted_logprobs = list(
                    sorted(
                        zip(current_logprobs, range(len(current_logprobs))),
                        reverse=True,
                    )
                )
                top_logprob = {
                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                    for logprob, i in sorted_logprobs[:logprobs]
                }
                top_logprob.update({token_str: current_logprobs[int(token)]})
                logprobs_or_none = {
                    "tokens": [
                        self.detokenize([token]).decode("utf-8", errors="ignore")
                    ],
                    "text_offset": [text_offset],
                    "token_logprobs": [current_logprobs[int(token)]],
                    "top_logprobs": [top_logprob],
                }

            if token_end_position >= end:
                last_text = self.detokenize([token])
                if token_end_position == end - 1:
                    break
                returned_tokens += 1
                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": model_name,
                    "choices": [
                        {
                            "text": last_text[
                                : len(last_text) - (token_end_position - end)
                            ].decode("utf-8", errors="ignore"),
                            "index": 0,
                            "logprobs": logprobs_or_none,
                            "finish_reason": None,
                        }
                    ],
                }
                break

            returned_tokens += 1
            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "choices": [
                    {
                        "text": self.detokenize([token]).decode(
                            "utf-8", errors="ignore"
                        ),
                        "index": 0,
                        "logprobs": logprobs_or_none,
                        "finish_reason": None,
                    }
                ],
            }
        # Final chunk: empty text, carries only the finish reason.
        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "text": "",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": finish_reason,
                }
            ],
        }
        if self.cache:
            if self.verbose:
                print("Llama._create_completion: cache save", file=sys.stderr)
            self.cache[prompt_tokens + completion_tokens] = self.save_state()
            # Diagnostics go to stderr only when verbose output is enabled.
            if self.verbose:
                print("Llama._create_completion: cache saved", file=sys.stderr)
        return

    if self.cache:
        if self.verbose:
            print("Llama._create_completion: cache save", file=sys.stderr)
        self.cache[prompt_tokens + completion_tokens] = self.save_state()

    text_str = text.decode("utf-8", errors="ignore")

    if echo:
        text_str = prompt + text_str

    # Without a suffix token the suffix was not part of the prompt tokens, so
    # it is appended to the returned text instead.
    if suffix_token_id < 0 and suffix is not None:
        text_str = text_str + suffix

    logprobs_or_none: Optional[CompletionLogprobs] = None
    if logprobs is not None:
        text_offset = 0 if echo else len(prompt)
        token_offset = 0 if echo else len(prompt_tokens[1:])
        text_offsets: List[int] = []
        token_logprobs: List[Optional[float]] = []
        tokens: List[str] = []
        top_logprobs: List[Optional[Dict[str, float]]] = []

        if echo:
            # Remove leading BOS token
            all_tokens = prompt_tokens[1:] + completion_tokens
        else:
            all_tokens = completion_tokens

        all_token_strs = [
            self.detokenize([token], prev_tokens=all_tokens[:i]).decode(
                "utf-8", errors="ignore"
            )
            for i, token in enumerate(all_tokens)
        ]
        all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
        # TODO: may be able to change this loop to use np.take_along_dim
        for idx, (token, token_str, logprobs_token) in enumerate(
            zip(all_tokens, all_token_strs, all_logprobs)
        ):
            if token == self.token_bos():
                continue
            text_offsets.append(
                text_offset
                + len(
                    self.detokenize(all_tokens[:idx]).decode(
                        "utf-8", errors="ignore"
                    )
                )
            )
            tokens.append(token_str)
            sorted_logprobs = list(
                sorted(
                    zip(logprobs_token, range(len(logprobs_token))), reverse=True
                )
            )
            token_logprobs.append(logprobs_token[int(token)])
            top_logprob: Optional[Dict[str, float]] = {
                self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
                    "utf-8", errors="ignore"
                ): logprob
                for logprob, i in sorted_logprobs[:logprobs]
            }
            top_logprob.update({token_str: logprobs_token[int(token)]})
            top_logprobs.append(top_logprob)
        # Weird idiosyncrasy of the OpenAI API where
        # token_logprobs and top_logprobs are null for
        # the first token.
        if echo and len(all_tokens) > 0:
            token_logprobs[0] = None
            top_logprobs[0] = None
        logprobs_or_none = {
            "tokens": tokens,
            "text_offset": text_offsets,
            "token_logprobs": token_logprobs,
            "top_logprobs": top_logprobs,
        }

    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "text": text_str,
                "index": 0,
                "logprobs": logprobs_or_none,
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": len(prompt_tokens),
            "completion_tokens": len(completion_tokens),
            "total_tokens": len(prompt_tokens) + len(completion_tokens),
        },
    }
2023-04-01 17:01:27 +00:00
def create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Complete a prompt, either in one piece or as a stream of chunks.

    Every argument is forwarded to ``_create_completion``; the only value
    rewritten on the way is ``max_tokens``, where None is passed on as -1.

    Args:
        prompt: The prompt to generate text from.
        suffix: A suffix to append to the generated text. If None, no suffix is appended.
        max_tokens: Upper bound on the number of generated tokens. If it is None or <= 0, generation is unlimited and only bounded by n_ctx.
        temperature: Sampling temperature.
        top_p: Top-p value for nucleus sampling, as described in "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: Min-p value for minimum-p sampling, as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: Typical-p value for locally typical sampling, as described in https://arxiv.org/abs/2202.00666.
        logprobs: Number of logprobs to return. If None, no logprobs are returned.
        echo: Whether to echo the prompt.
        stop: A list of strings that end generation when encountered.
        frequency_penalty: Penalty applied to tokens based on their frequency in the prompt.
        presence_penalty: Penalty applied to tokens based on their presence in the prompt.
        repeat_penalty: Penalty applied to repeated tokens.
        top_k: Top-k value for sampling, as described in "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        stream: Whether to stream the results.
        seed: Seed used for sampling.
        tfs_z: Tail-free sampling parameter, see https://www.trentonbricken.com/Tail-Free-Sampling/.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: Target cross-entropy (or surprise) of the generated text. Higher values give more surprising, less predictable text; lower values give less surprising, more predictable text.
        mirostat_eta: Learning rate used to update `mu` from the error between the target and the observed surprisal of the sampled word. Larger values update `mu` more quickly, smaller values more slowly.
        model: Name to report for the model in the completion object.
        stopping_criteria: A list of stopping criteria to use.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use for constrained sampling.
        logit_bias: A logit bias to use.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        The completion response when ``stream`` is falsy, otherwise an
        iterator of streamed completion chunks.
    """
    forwarded = dict(
        prompt=prompt,
        suffix=suffix,
        # None means "no limit", which the worker expects as a value <= 0.
        max_tokens=max_tokens if max_tokens is not None else -1,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    results = self._create_completion(**forwarded)

    if not stream:
        # Non-streaming callers get the first item the generator produces.
        completion: Completion = next(results)  # type: ignore
        return completion

    chunks: Iterator[CreateCompletionStreamResponse] = results
    return chunks
2023-03-28 08:03:57 +00:00
def __call__(
    self,
    prompt: str,
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt.

    Calling the model object directly is a shorthand for `create_completion`:
    every argument is forwarded unchanged and the result is returned as-is.

    Args:
        prompt: The prompt to generate text from.
        suffix: A suffix to append to the generated text. If None, no suffix is appended.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        logprobs: The number of logprobs to return. If None, no logprobs are returned.
        echo: Whether to echo the prompt.
        stop: A string or list of strings at which generation stops.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        stream: Whether to stream the results.
        seed: The seed to use for sampling.
        tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The target cross-entropy (or surprise) value for the generated text. Higher values give more surprising text, lower values more predictable text.
        mirostat_eta: The learning rate used to update `mu` from the error between target and observed surprisal. Larger values update `mu` faster.
        model: The name to use for the model in the completion object.
        stopping_criteria: A list of stopping criteria to use.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use for constrained sampling.
        logit_bias: A logit bias to use.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        Response object containing the generated text, or an iterator of
        stream chunks when `stream` is true (whatever `create_completion`
        returns).
    """
    # NOTE: the `stop` default is a shared list object; it is only forwarded
    # here, never mutated by this method.
    request = dict(
        prompt=prompt,
        suffix=suffix,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    return self.create_completion(**request)
2023-04-04 00:12:44 +00:00
def create_chat_completion(
    self,
    messages: List[ChatCompletionRequestMessage],
    functions: Optional[List[ChatCompletionFunction]] = None,
    function_call: Optional[ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[ChatCompletionTool]] = None,
    tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    seed: Optional[int] = None,
    response_format: Optional[ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
) -> Union[
    CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
    """Generate a chat completion from a list of messages.

    The work is delegated to a chat handler, resolved in this order: the
    explicit `self.chat_handler`, then the entry for `self.chat_format` in
    `self._chat_handlers`, then the handler registered for that format in
    `llama_chat_format`. All arguments are forwarded to it by keyword,
    together with this instance as `llama`.

    Args:
        messages: A list of messages to generate a response for.
        functions: A list of functions to use for the chat completion.
        function_call: A function call to use for the chat completion.
        tools: A list of tools to use for the chat completion.
        tool_choice: A tool choice to use for the chat completion.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
        stream: Whether to stream the results.
        stop: A string or list of strings at which generation stops.
        seed: The seed to use for sampling.
        response_format: The response format to use for the chat completion. Use { "type": "json_object" } to constrain output to only valid json.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        tfs_z: The tail-free sampling parameter.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The mirostat sampling tau parameter.
        mirostat_eta: The mirostat sampling eta parameter.
        model: The name to use for the model in the completion object.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use.
        logit_bias: A logit bias to use.
        logprobs: Forwarded to the chat handler; presumably toggles returning
            log probabilities (confirm against the handler in use).
        top_logprobs: Forwarded to the chat handler; presumably the number of
            top log probabilities per token (confirm against the handler).

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    # Resolve lazily, most specific source first; later lookups only run
    # when the earlier ones yield nothing truthy.
    handler = self.chat_handler
    if not handler:
        handler = self._chat_handlers.get(self.chat_format)
    if not handler:
        handler = llama_chat_format.get_chat_completion_handler(self.chat_format)

    request = dict(
        llama=self,
        messages=messages,
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        stream=stream,
        stop=stop,
        seed=seed,
        response_format=response_format,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    return handler(**request)
2024-02-12 20:56:07 +00:00
def create_chat_completion_openai_v1(
    self,
    *args: Any,
    **kwargs: Any,
):
    """Generate a chat completion with return type based on the OpenAI v1 API.

    OpenAI python package is required to use this method.

    You can install it with `pip install openai`.

    Args:
        *args: Positional arguments to pass to create_chat_completion.
        **kwargs: Keyword arguments to pass to create_chat_completion.
            Streaming is detected from the `stream` keyword only.

    Raises:
        ImportError: If the `openai` package is not installed.

    Returns:
        An `openai` `ChatCompletion`, or, when `stream=True`, a generator of
        `ChatCompletionChunk` objects built from the streamed chunks.
    """
    # Guard only the import: an ImportError raised while actually running the
    # completion must surface as-is instead of being reported as a missing
    # `openai` installation.
    try:
        from openai.types.chat import ChatCompletion, ChatCompletionChunk
    except ImportError as err:
        raise ImportError(
            "To use create_chat_completion_openai_v1, you must install the openai package. "
            "You can install it with `pip install openai`."
        ) from err

    stream = kwargs.get("stream", False)  # type: ignore
    assert isinstance(stream, bool)
    if stream:
        return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
    return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
2023-04-05 10:52:17 +00:00
def __getstate__(self):
    """Return the keyword arguments describing this instance for pickling.

    The result is a plain dict of settings read from the instance and from
    its `model_params` / `context_params` structures; `__setstate__` feeds
    it back into `__init__`.
    """
    model_params = self.model_params
    context_params = self.context_params
    return {
        "model_path": self.model_path,
        # Model Params
        "n_gpu_layers": model_params.n_gpu_layers,
        "split_mode": model_params.split_mode,
        "main_gpu": model_params.main_gpu,
        "tensor_split": self.tensor_split,
        "vocab_only": model_params.vocab_only,
        "use_mmap": model_params.use_mmap,
        "use_mlock": model_params.use_mlock,
        "kv_overrides": self.kv_overrides,
        # Context Params
        "seed": context_params.seed,
        "n_ctx": context_params.n_ctx,
        "n_batch": self.n_batch,
        "n_threads": context_params.n_threads,
        "n_threads_batch": context_params.n_threads_batch,
        "rope_scaling_type": context_params.rope_scaling_type,
        "pooling_type": context_params.pooling_type,
        "rope_freq_base": context_params.rope_freq_base,
        "rope_freq_scale": context_params.rope_freq_scale,
        "yarn_ext_factor": context_params.yarn_ext_factor,
        "yarn_attn_factor": context_params.yarn_attn_factor,
        "yarn_beta_fast": context_params.yarn_beta_fast,
        "yarn_beta_slow": context_params.yarn_beta_slow,
        "yarn_orig_ctx": context_params.yarn_orig_ctx,
        "logits_all": context_params.logits_all,
        # The state key is `embedding` while the context field is `embeddings`.
        "embedding": context_params.embeddings,
        "offload_kqv": context_params.offload_kqv,
        "flash_attn": context_params.flash_attn,
        # Sampling Params
        "last_n_tokens_size": self.last_n_tokens_size,
        # LoRA Params
        "lora_base": self.lora_base,
        "lora_scale": self.lora_scale,
        "lora_path": self.lora_path,
        # Backend Params
        "numa": self.numa,
        # Chat Format Params
        "chat_format": self.chat_format,
        "chat_handler": self.chat_handler,
        # Speculative Decoding
        "draft_model": self.draft_model,
        # KV cache quantization
        "type_k": context_params.type_k,
        "type_v": context_params.type_v,
        # Misc
        "verbose": self.verbose,
    }
def __setstate__(self, state):
    """Rebuild the instance from pickled state.

    `state` is treated as a mapping of constructor keyword arguments and is
    passed straight to `__init__`, so the instance is re-initialised rather
    than having attributes restored one by one.
    """
    init_kwargs = state
    self.__init__(**init_kwargs)
2023-04-05 10:52:17 +00:00
2023-04-24 21:51:25 +00:00
def save_state(self) -> LlamaState:
    """Snapshot the current llama context into a `LlamaState`.

    The native state is copied into a buffer sized by
    `llama_get_state_size`, then trimmed to the number of bytes that
    `llama_copy_state_data` reports as written. The snapshot also carries
    copies of the scores and input ids plus the current token count.

    Raises:
        RuntimeError: If the reported number of copied bytes exceeds the
            size of the allocated buffer.
    """
    assert self._ctx.ctx is not None

    def trace(message: str) -> None:
        # Progress output goes to stderr and only in verbose mode.
        if self.verbose:
            print(message, file=sys.stderr)

    trace("Llama.save_state: saving llama state")
    state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
    trace(f"Llama.save_state: got state size: {state_size}")
    scratch = (ctypes.c_uint8 * int(state_size))()
    trace("Llama.save_state: allocated state")
    n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, scratch)
    trace(f"Llama.save_state: copied llama state: {n_bytes}")
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")

    # Keep only the bytes actually written so the snapshot is not padded
    # out to the (possibly larger) reported state size.
    compact = (ctypes.c_uint8 * int(n_bytes))()
    ctypes.memmove(compact, scratch, int(n_bytes))
    trace(f"Llama.save_state: saving {n_bytes} bytes of llama state")
    return LlamaState(
        scores=self._scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=bytes(compact),
        llama_state_size=n_bytes,
    )
def load_state(self, state: LlamaState) -> None:
    """Restore a snapshot previously produced by `save_state`.

    Scores, input ids and the token count are restored on the Python side,
    then the raw state bytes are handed to `llama_set_state_data`.

    Raises:
        RuntimeError: If the native call does not report exactly
            `state.llama_state_size` bytes.
    """
    assert self._ctx.ctx is not None
    # Only filling in up to `n_tokens` and then zero-ing out the rest
    self.scores[: state.n_tokens, :] = state.scores.copy()
    self.scores[state.n_tokens :, :] = 0.0
    self.input_ids = state.input_ids.copy()
    self.n_tokens = state.n_tokens

    expected_size = state.llama_state_size
    buffer_type = ctypes.c_uint8 * expected_size
    native_buffer = buffer_type.from_buffer_copy(state.llama_state)
    restored_size = llama_cpp.llama_set_state_data(self._ctx.ctx, native_buffer)
    if restored_size != expected_size:
        raise RuntimeError("Failed to set llama state data")
2023-05-20 12:13:41 +00:00
def n_ctx(self) -> int:
    """Return the context window size, as reported by the underlying context."""
    context = self._ctx
    return context.n_ctx()
2023-05-20 12:13:41 +00:00
def n_embd(self) -> int:
    """Return the embedding size, as reported by the underlying model."""
    model = self._model
    return model.n_embd()
2023-05-20 12:13:41 +00:00
def n_vocab(self) -> int:
    """Return the vocabulary size, as reported by the underlying model."""
    model = self._model
    return model.n_vocab()
2023-05-20 12:13:41 +00:00
2024-02-08 01:07:03 +00:00
def tokenizer(self) -> LlamaTokenizer:
    """Return a new `LlamaTokenizer` wrapping this model."""
    wrapper = LlamaTokenizer(self)
    return wrapper
2023-04-05 10:52:17 +00:00
2023-08-24 04:17:00 +00:00
def token_eos(self) -> int:
    """Return the end-of-sequence token, as reported by the underlying model."""
    model = self._model
    return model.token_eos()
2023-04-01 21:29:30 +00:00
2023-08-24 04:17:00 +00:00
def token_bos(self) -> int:
    """Return the beginning-of-sequence token, as reported by the underlying model."""
    model = self._model
    return model.token_bos()
2023-04-12 18:05:11 +00:00
2023-08-24 04:17:00 +00:00
def token_nl(self) -> int:
    """Return the newline token, as reported by the underlying model."""
    model = self._model
    return model.token_nl()
2023-05-17 05:53:26 +00:00
2024-04-26 01:32:44 +00:00
def pooling_type(self) -> str:
    """Return the pooling type, as reported by the underlying context."""
    context = self._ctx
    return context.pooling_type()
2023-04-12 18:05:11 +00:00
@staticmethod
def logits_to_logprobs(
    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert logits to log-probabilities with a log-softmax along `axis`.

    The per-slice maximum is subtracted before exponentiating; non-finite
    maxima are replaced by 0 so they do not propagate into the shift. The
    subtraction is computed in single precision.

    Args:
        logits: Array (or nested list) of raw scores.
        axis: Axis along which the softmax is normalised.

    Returns:
        Array with the same shape as `logits` holding log-probabilities.
    """
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
    peak: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
    if peak.ndim > 0:
        peak[~np.isfinite(peak)] = 0
    elif not np.isfinite(peak):
        peak = 0
    shifted = np.subtract(logits, peak, dtype=np.single)
    weights = np.exp(shifted)
    # Suppress warnings about log of zero
    with np.errstate(divide="ignore"):
        log_normaliser = np.log(np.sum(weights, axis=axis, keepdims=True))
    return shifted - log_normaliser
2023-05-07 23:31:26 +00:00
@staticmethod
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
    """Return the number of leading positions at which `a` and `b` agree.

    Comparison stops at the first mismatch or at the end of the shorter
    sequence, whichever comes first.
    """
    matched = 0
    for left, right in zip(a, b):
        if left != right:
            return matched
        matched += 1
    return matched
2023-05-25 18:11:33 +00:00
2024-02-21 21:25:10 +00:00
@classmethod
def from_pretrained(
    cls,
    repo_id: str,
    filename: Optional[str],
    local_dir: Optional[Union[str, os.PathLike[str]]] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    **kwargs: Any,
) -> "Llama":
    """Create a Llama model from a pretrained model name or path.

    This method requires the huggingface-hub package.
    You can install it with `pip install huggingface-hub`.

    Args:
        repo_id: The model repo id.
        filename: A filename or glob pattern to match the model file in the repo.
            Exactly one file listed for the repo must match.
        local_dir: The local directory to save the model to.
        local_dir_use_symlinks: Whether to use symlinks when downloading the model.
        cache_dir: Forwarded to `hf_hub_download` as its `cache_dir`.
        **kwargs: Additional keyword arguments to pass to the Llama constructor.

    Raises:
        ImportError: If the huggingface-hub package is not installed.
        ValueError: If no file, or more than one file, matches `filename`.

    Returns:
        A Llama model."""
    try:
        from huggingface_hub import hf_hub_download, HfFileSystem
        from huggingface_hub.utils import validate_repo_id
    except ImportError as err:
        raise ImportError(
            "Llama.from_pretrained requires the huggingface-hub package. "
            "You can install it with `pip install huggingface-hub`."
        ) from err

    validate_repo_id(repo_id)

    hffs = HfFileSystem()

    # `ls` may yield either plain paths or dict entries carrying a "name".
    files = [
        file["name"] if isinstance(file, dict) else file
        for file in hffs.ls(repo_id)
    ]

    # Strip the repo id so patterns are matched against repo-relative paths.
    file_list: List[str] = [str(Path(file).relative_to(repo_id)) for file in files]

    matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

    if not matching_files:
        raise ValueError(
            f"No file found in {repo_id} that match {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    if len(matching_files) > 1:
        # Report the same repo-relative listing as the "no match" error so
        # the names shown are the ones the pattern is matched against.
        raise ValueError(
            f"Multiple files found in {repo_id} matching {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    (matching_file,) = matching_files

    subfolder = str(Path(matching_file).parent)
    filename = Path(matching_file).name

    # download the file
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
        cache_dir=cache_dir,
    )

    if local_dir is None:
        # Second call with local_files_only=True, so it resolves the path
        # from files already on disk.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            cache_dir=cache_dir,
            local_files_only=True,
        )
    else:
        model_path = os.path.join(local_dir, filename)

    return cls(
        model_path=model_path,
        **kwargs,
    )
2024-02-08 01:07:03 +00:00
2024-01-17 14:16:13 +00:00
class LlamaState:
    """Plain container for a saved llama context snapshot.

    It stores the arguments it is given without copying or validating them.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Python-side bookkeeping.
        self.n_tokens = n_tokens
        self.input_ids = input_ids
        self.scores = scores
        # Raw native state and its length in bytes.
        self.llama_state_size = llama_state_size
        self.llama_state = llama_state
# A logits processor maps (input_ids, scores) to a new scores array.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors that is itself callable as one processor."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        """Apply each processor in list order, feeding each result to the next."""
        current = scores
        for apply in self:
            current = apply(input_ids, current)
        return current
# A stopping criterion maps (input_ids, logits) to a stop/continue flag.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria that is itself callable as one criterion."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        """Return True if any criterion asks to stop.

        Every criterion is evaluated, even after one has already returned
        a truthy value.
        """
        verdicts = [criterion(input_ids, logits) for criterion in self]
        return any(verdicts)
2024-05-14 13:50:53 +00:00
class MinTokensLogitsProcessor(LogitsProcessor):
    """Logits processor that blocks the EOS token until enough tokens exist.

    The length of `input_ids` seen on the first call is taken as the prompt
    length; while fewer than `min_tokens` tokens have been added beyond it,
    the EOS entry of `scores` is set to -inf.
    """

    def __init__(self, min_tokens: int, token_eos: int):
        self.min_tokens = min_tokens
        self.token_eos = token_eos
        # Filled in lazily on the first call.
        # NOTE(review): never reset afterwards, so reusing one instance for a
        # second generation keeps the first prompt length — confirm intended.
        self.prompt_tokens = None

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        """Mask EOS in `scores` (in place) while below the minimum, then return it."""
        if self.prompt_tokens is None:
            self.prompt_tokens = len(input_ids)
        generated = len(input_ids) - self.prompt_tokens
        if generated >= self.min_tokens:
            return scores
        scores[self.token_eos] = -np.inf
        return scores