import os
import sys
import uuid
import time
import math
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator, Dict
from collections import deque

from . import llama_cpp
from .llama_types import *


class Llama:
    """High-level Python wrapper for a llama.cpp model."""

    def __init__(
        self,
        model_path: str,
        # NOTE: These parameters are likely to change in the future.
        n_ctx: int = 512,
        n_parts: int = -1,
        seed: int = 1337,
        f16_kv: bool = False,
        logits_all: bool = False,
        vocab_only: bool = False,
        use_mmap: bool = True,
        use_mlock: bool = False,
        embedding: bool = False,
        n_threads: Optional[int] = None,
        n_batch: int = 8,
        last_n_tokens_size: int = 64,
        verbose: bool = True,
    ):
        """Load a llama.cpp model from `model_path`.

        Args:
            model_path: Path to the model.
            n_ctx: Maximum context size.
            n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
            seed: Random seed. 0 for random.
            f16_kv: Use half-precision for key/value cache.
            logits_all: Return logits for all tokens, not just the last token.
            vocab_only: Only load the vocabulary, no weights.
            use_mmap: Use mmap if possible.
            use_mlock: Force the system to keep the model in RAM.
            embedding: Embedding mode only.
            n_threads: Number of threads to use. If None, the number of threads is automatically determined.
            n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
            last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
            verbose: Print verbose output to stderr.

        Raises:
            ValueError: If the model path does not exist.

        Returns:
            A Llama instance.
        """
        self.verbose = verbose
        self.model_path = model_path

        self.params = llama_cpp.llama_context_default_params()
        self.params.n_ctx = n_ctx
        self.params.n_parts = n_parts
        self.params.seed = seed
        self.params.f16_kv = f16_kv
        self.params.logits_all = logits_all
        self.params.vocab_only = vocab_only
        self.params.use_mmap = use_mmap
        self.params.use_mlock = use_mlock
        self.params.embedding = embedding

        self.last_n_tokens_size = last_n_tokens_size
        self.last_n_tokens_data = deque(
            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
            maxlen=self.last_n_tokens_size,
        )
        self.tokens_consumed = 0
        self.n_batch = min(n_ctx, n_batch)
        self.n_tokens = 0
        self.n_past = 0
        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.

        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)

        if not os.path.exists(model_path):
            raise ValueError(f"Model path does not exist: {model_path}")

        self.ctx = llama_cpp.llama_init_from_file(
            self.model_path.encode("utf-8"), self.params
        )

        if self.verbose:
            print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
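
    # Illustrative construction (a sketch; the model path below is a placeholder):
    #
    #   llm = Llama(model_path="./models/ggml-model.bin", n_ctx=512, seed=1337)
    #
    # Most keyword arguments correspond to the llama_context params assigned above.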

    def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
        """Tokenize a string.

        Args:
            text: The utf-8 encoded string to tokenize.

        Raises:
            RuntimeError: If the tokenization failed.

        Returns:
            A list of tokens.
        """
        assert self.ctx is not None
        n_ctx = llama_cpp.llama_n_ctx(self.ctx)
        tokens = (llama_cpp.llama_token * int(n_ctx))()
        n_tokens = llama_cpp.llama_tokenize(
            self.ctx,
            text,
            tokens,
            n_ctx,
            llama_cpp.c_bool(True),
        )
        if int(n_tokens) < 0:
            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
        return list(tokens[:n_tokens])

    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
        """Detokenize a list of tokens.

        Args:
            tokens: The list of tokens to detokenize.

        Returns:
            The detokenized string.
        """
        assert self.ctx is not None
        output = b""
        for token in tokens:
            output += llama_cpp.llama_token_to_str(self.ctx, token)
        return output
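
    # Round-trip sketch (illustrative; the exact bytes may differ slightly because
    # tokenize() prepends a BOS token and the tokenizer may add leading spaces):
    #
    #   >>> tokens = llm.tokenize(b"Hello, world!")
    #   >>> llm.detokenize(tokens)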

    def reset(self):
        """Reset the model state."""
        self.last_n_tokens_data.extend(
            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
        )
        self.tokens_consumed = 0
        self.n_tokens = 0
        self.n_past = 0
        self.all_logits = []

    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
        """Evaluate a list of tokens.

        Args:
            tokens: The list of tokens to evaluate.
        """
        assert self.ctx is not None
        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
        for i in range(0, len(tokens), self.n_batch):
            batch = tokens[i : min(len(tokens), i + self.n_batch)]
            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
            self.n_tokens = len(batch)
            return_code = llama_cpp.llama_eval(
                ctx=self.ctx,
                tokens=(llama_cpp.llama_token * len(batch))(*batch),
                n_tokens=llama_cpp.c_int(self.n_tokens),
                n_past=llama_cpp.c_int(self.n_past),
                n_threads=llama_cpp.c_int(self.n_threads),
            )
            if int(return_code) != 0:
                raise RuntimeError(f"llama_eval returned {return_code}")
            self.last_n_tokens_data.extend(batch)
            self.tokens_consumed += len(batch)
            if self.params.logits_all:
                self.all_logits.extend(self._logits())
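
    # Note on eval() batching (descriptive): prompts longer than n_batch are fed to
    # llama_eval in chunks of at most n_batch tokens; n_past counts the tokens already
    # in the context, so each chunk is appended after the previous ones, and it is
    # clamped to n_ctx - len(batch) so a call never overruns the context window.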

    def _logits(self) -> List[List[float]]:
        """Return the logits from the last call to llama_eval."""
        assert self.ctx is not None
        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
        cols = int(n_vocab)
        rows = self.n_tokens if self.params.logits_all else 1
        logits_view = llama_cpp.llama_get_logits(self.ctx)
        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
        return logits

    def sample(
        self,
        top_k: int,
        top_p: float,
        temp: float,
        repeat_penalty: float,
    ):
        """Sample a token from the model.

        Args:
            top_k: The top-k sampling parameter.
            top_p: The top-p sampling parameter.
            temp: The temperature parameter.
            repeat_penalty: The repeat penalty parameter.

        Returns:
            The sampled token.
        """
        assert self.ctx is not None
        return llama_cpp.llama_sample_top_p_top_k(
            ctx=self.ctx,
            last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                *self.last_n_tokens_data
            ),
            last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
            top_k=llama_cpp.c_int(top_k),
            top_p=llama_cpp.c_float(top_p),
            temp=llama_cpp.c_float(temp),
            repeat_penalty=llama_cpp.c_float(repeat_penalty),
        )
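
    # Illustrative low-level loop (a sketch; generate() wraps this same pattern):
    #
    #   >>> llm.reset()
    #   >>> llm.eval(llm.tokenize(b"Q: What is the capital of France? A:"))
    #   >>> token = llm.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
    #   >>> llm.detokenize([token])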

    def generate(
        self,
        tokens: Sequence[llama_cpp.llama_token],
        top_k: int,
        top_p: float,
        temp: float,
        repeat_penalty: float,
        reset: bool = True,
    ) -> Generator[
        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
    ]:
        """Create a generator of tokens from a prompt.

        Examples:
            >>> llama = Llama("models/ggml-7b.bin")
            >>> tokens = llama.tokenize(b"Hello, world!")
            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
            ...     print(llama.detokenize([token]))

        Args:
            tokens: The prompt tokens.
            top_k: The top-k sampling parameter.
            top_p: The top-p sampling parameter.
            temp: The temperature parameter.
            repeat_penalty: The repeat penalty parameter.
            reset: Whether to reset the model state.

        Yields:
            The generated tokens.
        """
        assert self.ctx is not None
        if reset:
            self.reset()
        while True:
            self.eval(tokens)
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
            )
            tokens_or_none = yield token
            tokens = [token]
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

    def create_embedding(self, input: str) -> Embedding:
        """Embed a string.

        Args:
            input: The utf-8 encoded string to embed.

        Returns:
            An embedding object.
        """
        assert self.ctx is not None

        if self.params.embedding == False:
            raise RuntimeError(
                "Llama model must be created with embedding=True to call this method"
            )

        if self.verbose:
            llama_cpp.llama_reset_timings(self.ctx)

        tokens = self.tokenize(input.encode("utf-8"))
        self.reset()
        self.eval(tokens)
        n_tokens = len(tokens)
        embedding = llama_cpp.llama_get_embeddings(self.ctx)[
            : llama_cpp.llama_n_embd(self.ctx)
        ]

        if self.verbose:
            llama_cpp.llama_print_timings(self.ctx)

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": embedding,
                    "index": 0,
                }
            ],
            "model": self.model_path,
            "usage": {
                "prompt_tokens": n_tokens,
                "total_tokens": n_tokens,
            },
        }

    def embed(self, input: str) -> List[float]:
        """Embed a string.

        Args:
            input: The utf-8 encoded string to embed.

        Returns:
            A list of embeddings.
        """
        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
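
    # Illustrative usage (a sketch; requires a model loaded with embedding=True,
    # and the path below is a placeholder):
    #
    #   >>> llm = Llama(model_path="./models/ggml-model.bin", embedding=True)
    #   >>> llm.embed("Hello, world!")          # plain list of floats
    #   >>> llm.create_embedding("Hello, world!")["data"][0]["embedding"]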

    def _create_completion(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 16,
        temperature: float = 0.8,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
        stop: Optional[List[str]] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
        assert self.ctx is not None
        completion_id = f"cmpl-{str(uuid.uuid4())}"
        created = int(time.time())
        completion_tokens: List[llama_cpp.llama_token] = []
        # Add blank space to start of prompt to match OG llama tokenizer
        prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
        text = b""
        returned_characters = 0
        stop = stop if stop is not None else []

        if self.verbose:
            llama_cpp.llama_reset_timings(self.ctx)

        if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
            raise ValueError(
                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
            )

        if stop != []:
            stop_sequences = [s.encode("utf-8") for s in stop]
        else:
            stop_sequences = []

        text_offset = 0
        text_offsets: List[int] = []
        token_logprobs: List[float] = []
        tokens: List[str] = []
        top_logprobs: List[Dict[str, float]] = []

        self.reset()
        self.eval(prompt_tokens)

        if logprobs is not None and self.params.logits_all is False:
            raise ValueError(
                "logprobs is not supported for models created with logits_all=False"
            )

        if logprobs is not None:
            token_strs = [
                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
            ]
            logprobs_all = [
                [Llama.logit_to_logprob(logit) for logit in row]
                for row in self.all_logits
            ]
            for token, token_str, logprobs_token in zip(
                prompt_tokens, token_strs, logprobs_all
            ):
                text_offsets.append(text_offset)
                text_offset += len(token_str)
                tokens.append(token_str)
                sorted_logprobs = list(
                    sorted(
                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
                    )
                )
                token_logprobs.append(sorted_logprobs[int(token)][0])
                top_logprob = {
                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
                    for logprob, i in sorted_logprobs[:logprobs]
                }
                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
                top_logprobs.append(top_logprob)

        finish_reason = "length"
        while True:
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                temp=temperature,
                repeat_penalty=repeat_penalty,
            )
            if token == llama_cpp.llama_token_eos():
                text = self.detokenize(completion_tokens)
                finish_reason = "stop"
                break
            completion_tokens.append(token)

            all_text = self.detokenize(completion_tokens)
            any_stop = [s for s in stop_sequences if s in all_text]
            if len(any_stop) > 0:
                first_stop = any_stop[0]
                text = all_text[: all_text.index(first_stop)]
                finish_reason = "stop"
                break

            if stream:
                start = returned_characters
                longest = 0
                # We want to avoid yielding any characters from
                # the generated text if they are part of a stop
                # sequence.
                for s in stop_sequences:
                    for i in range(len(s), 0, -1):
                        if all_text.endswith(s[:i]):
                            if i > longest:
                                longest = i
                            break
                text = all_text[: len(all_text) - longest]
                returned_characters += len(text[start:])
                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": self.model_path,
                    "choices": [
                        {
                            "text": text[start:].decode("utf-8"),
                            "index": 0,
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                }

            if logprobs is not None:
                # TODO: Confirm whether this should happen before or after
                # the next eval.
                token_str = self.detokenize([token]).decode("utf-8")
                text_offsets.append(text_offset)
                text_offset += len(token_str)
                tokens.append(token_str)
                logprobs_token = [
                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
                ]
                sorted_logprobs = list(
                    sorted(
                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
                    )
                )
                token_logprobs.append(sorted_logprobs[int(token)][0])
                top_logprob = {
                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
                    for logprob, i in sorted_logprobs[:logprobs]
                }
                top_logprob.update({token_str: logprobs_token[int(token)]})
                top_logprobs.append(top_logprob)

            if len(completion_tokens) >= max_tokens:
                text = self.detokenize(completion_tokens)
                finish_reason = "length"
                break

            self.eval([token])

        if stream:
            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": self.model_path,
                "choices": [
                    {
                        "text": text[returned_characters:].decode("utf-8"),
                        "index": 0,
                        "logprobs": None,
                        "finish_reason": finish_reason,
                    }
                ],
            }
            return

        text = text.decode("utf-8")

        if echo:
            text = prompt + text

        if suffix is not None:
            text = text + suffix

        logprobs_or_none: Optional[CompletionLogprobs] = None
        if logprobs is not None:
            logprobs_or_none = {
                "tokens": tokens,
                "text_offset": text_offsets,
                "token_logprobs": token_logprobs,
                "top_logprobs": top_logprobs,
            }

        if self.verbose:
            llama_cpp.llama_print_timings(self.ctx)

        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": self.model_path,
            "choices": [
                {
                    "text": text,
                    "index": 0,
                    "logprobs": logprobs_or_none,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt_tokens),
                "completion_tokens": len(completion_tokens),
                "total_tokens": len(prompt_tokens) + len(completion_tokens),
            },
        }

    def create_completion(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
        stop: Optional[List[str]] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

        Args:
            prompt: The prompt to generate text from.
            suffix: A suffix to append to the generated text. If None, no suffix is appended.
            max_tokens: The maximum number of tokens to generate.
            temperature: The temperature to use for sampling.
            top_p: The top-p value to use for sampling.
            logprobs: The number of logprobs to return. If None, no logprobs are returned.
            echo: Whether to echo the prompt.
            stop: A list of strings to stop generation when encountered.
            repeat_penalty: The penalty to apply to repeated tokens.
            top_k: The top-k value to use for sampling.
            stream: Whether to stream the results.

        Raises:
            ValueError: If the requested tokens exceed the context window.
            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

        Returns:
            Response object containing the generated text.
        """
        completion_or_chunks = self._create_completion(
            prompt=prompt,
            suffix=suffix,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
            stop=stop,
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks
            return chunks
        completion: Completion = next(completion_or_chunks)  # type: ignore
        return completion
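
    # Illustrative usage (a sketch; the model path and prompt are placeholders):
    #
    #   >>> llm = Llama(model_path="./models/ggml-model.bin")
    #   >>> out = llm.create_completion("Q: Name the planets. A:", max_tokens=32, stop=["Q:"])
    #   >>> out["choices"][0]["text"]
    #
    # With stream=True the call returns an iterator of completion chunks instead.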

    def __call__(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
        stop: Optional[List[str]] = [],
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

        Args:
            prompt: The prompt to generate text from.
            suffix: A suffix to append to the generated text. If None, no suffix is appended.
            max_tokens: The maximum number of tokens to generate.
            temperature: The temperature to use for sampling.
            top_p: The top-p value to use for sampling.
            logprobs: The number of logprobs to return. If None, no logprobs are returned.
            echo: Whether to echo the prompt.
            stop: A list of strings to stop generation when encountered.
            repeat_penalty: The penalty to apply to repeated tokens.
            top_k: The top-k value to use for sampling.
            stream: Whether to stream the results.

        Raises:
            ValueError: If the requested tokens exceed the context window.
            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

        Returns:
            Response object containing the generated text.
        """
        return self.create_completion(
            prompt=prompt,
            suffix=suffix,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
            stop=stop,
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
        )

    def _convert_text_completion_to_chat(
        self, completion: Completion
    ) -> ChatCompletion:
        return {
            "id": "chat" + completion["id"],
            "object": "chat.completion",
            "created": completion["created"],
            "model": completion["model"],
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion["choices"][0]["text"],
                    },
                    "finish_reason": completion["choices"][0]["finish_reason"],
                }
            ],
            "usage": completion["usage"],
        }

    def _convert_text_completion_chunks_to_chat(
        self,
        chunks: Iterator[CompletionChunk],
    ) -> Iterator[ChatCompletionChunk]:
        for i, chunk in enumerate(chunks):
            if i == 0:
                yield {
                    "id": "chat" + chunk["id"],
                    "model": chunk["model"],
                    "created": chunk["created"],
                    "object": "chat.completion.chunk",
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "role": "assistant",
                            },
                            "finish_reason": None,
                        }
                    ],
                }
            yield {
                "id": "chat" + chunk["id"],
                "model": chunk["model"],
                "created": chunk["created"],
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": chunk["choices"][0]["text"],
                        },
                        "finish_reason": chunk["choices"][0]["finish_reason"],
                    }
                ],
            }

    def create_chat_completion(
        self,
        messages: List[ChatCompletionMessage],
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[List[str]] = [],
        max_tokens: int = 128,
        repeat_penalty: float = 1.1,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a chat completion from a list of messages.

        Args:
            messages: A list of messages to generate a response for.
            temperature: The temperature to use for sampling.
            top_p: The top-p value to use for sampling.
            top_k: The top-k value to use for sampling.
            stream: Whether to stream the results.
            stop: A list of strings to stop generation when encountered.
            max_tokens: The maximum number of tokens to generate.
            repeat_penalty: The penalty to apply to repeated tokens.

        Returns:
            Generated chat completion or a stream of chat completion chunks.
        """
        stop = stop if stop is not None else []
        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
        chat_history = "\n".join(
            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
            for message in messages
        )
        PROMPT = f"\n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
        completion_or_chunks = self(
            prompt=PROMPT,
            stop=PROMPT_STOP + stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream=stream,
            max_tokens=max_tokens,
            repeat_penalty=repeat_penalty,
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
            return self._convert_text_completion_chunks_to_chat(chunks)
        else:
            completion: Completion = completion_or_chunks  # type: ignore
            return self._convert_text_completion_to_chat(completion)
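
    # Illustrative usage (a sketch; the message content is a placeholder):
    #
    #   >>> llm.create_chat_completion(
    #   ...     messages=[
    #   ...         {"role": "system", "content": "You are a helpful assistant."},
    #   ...         {"role": "user", "content": "Name the planets in the solar system."},
    #   ...     ],
    #   ... )["choices"][0]["message"]["content"]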

    def __del__(self):
        if self.ctx is not None:
            llama_cpp.llama_free(self.ctx)
            self.ctx = None

    def __getstate__(self):
        return dict(
            verbose=self.verbose,
            model_path=self.model_path,
            n_ctx=self.params.n_ctx,
            n_parts=self.params.n_parts,
            seed=self.params.seed,
            f16_kv=self.params.f16_kv,
            logits_all=self.params.logits_all,
            vocab_only=self.params.vocab_only,
            use_mmap=self.params.use_mmap,
            use_mlock=self.params.use_mlock,
            embedding=self.params.embedding,
            last_n_tokens_size=self.last_n_tokens_size,
            n_batch=self.n_batch,
            n_threads=self.n_threads,
        )

    def __setstate__(self, state):
        self.__init__(
            model_path=state["model_path"],
            n_ctx=state["n_ctx"],
            n_parts=state["n_parts"],
            seed=state["seed"],
            f16_kv=state["f16_kv"],
            logits_all=state["logits_all"],
            vocab_only=state["vocab_only"],
            use_mmap=state["use_mmap"],
            use_mlock=state["use_mlock"],
            embedding=state["embedding"],
            n_threads=state["n_threads"],
            n_batch=state["n_batch"],
            last_n_tokens_size=state["last_n_tokens_size"],
            verbose=state["verbose"],
        )
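
    # Pickling sketch (illustrative): __getstate__/__setstate__ rebuild the model from
    # model_path with the same parameters, so the weights themselves are not serialized.
    #
    #   >>> import pickle
    #   >>> llm_copy = pickle.loads(pickle.dumps(llm))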

    @staticmethod
    def token_eos() -> llama_cpp.llama_token:
        """Return the end-of-sequence token."""
        return llama_cpp.llama_token_eos()

    @staticmethod
    def token_bos() -> llama_cpp.llama_token:
        """Return the beginning-of-sequence token."""
        return llama_cpp.llama_token_bos()

    @staticmethod
    def logit_to_logprob(x: float) -> float:
        return math.log(1.0 + math.exp(x))