Update low level examples
This commit is contained in:
parent
a02aa121da
commit
0e9f227afd
6 changed files with 486 additions and 92 deletions
70
examples/low_level_api/Chat.py
Normal file
70
examples/low_level_api/Chat.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
#!/bin/python
|
||||||
|
import sys, os, datetime
|
||||||
|
from common import GptParams
|
||||||
|
from low_level_api_chat_cpp import LLaMAInteract
|
||||||
|
|
||||||
|
def env_or_def(env, default):
|
||||||
|
if (env in os.environ):
|
||||||
|
return os.environ[env]
|
||||||
|
return default
|
||||||
|
|
||||||
|
AI_NAME = env_or_def("AI_NAME", "ChatLLaMa")
|
||||||
|
MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin")
|
||||||
|
USER_NAME = env_or_def("USER_NAME", "USER")
|
||||||
|
N_PREDICTS = int(env_or_def("N_PREDICTS", "2048"))
|
||||||
|
N_THREAD = int(env_or_def("N_THREAD", "8"))
|
||||||
|
|
||||||
|
today = datetime.datetime.today()
|
||||||
|
DATE_YEAR=today.strftime("%Y")
|
||||||
|
DATE_TIME=today.strftime("%H:%M")
|
||||||
|
|
||||||
|
prompt=f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
|
||||||
|
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}'s requests immediately and with details and precision.
|
||||||
|
There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
|
||||||
|
The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
|
||||||
|
The transcript only includes text, it does not include markup like HTML and Markdown.
|
||||||
|
|
||||||
|
{USER_NAME}: Hello, {AI_NAME}!
|
||||||
|
{AI_NAME}: Hello {USER_NAME}! How may I help you today?
|
||||||
|
{USER_NAME}: What year is it?
|
||||||
|
{AI_NAME}: We are in {DATE_YEAR}.
|
||||||
|
{USER_NAME}: Please tell me the largest city in Europe.
|
||||||
|
{AI_NAME}: The largest city in Europe is Moscow, the capital of Russia.
|
||||||
|
{USER_NAME}: What can you tell me about Moscow?
|
||||||
|
{AI_NAME}: Moscow, on the Moskva River in western Russia, is the nation's cosmopolitan capital. In its historic core is the Kremlin, a complex that's home to the president and tsarist treasures in the Armoury. Outside its walls is Red Square, Russia’s symbolic center.
|
||||||
|
{USER_NAME}: What is a cat?
|
||||||
|
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
|
||||||
|
{USER_NAME}: How do I pass command line arguments to a Node.js program?
|
||||||
|
{AI_NAME}: The arguments are stored in process.argv.
|
||||||
|
|
||||||
|
argv[0] is the path to the Node. js executable.
|
||||||
|
argv[1] is the path to the script file.
|
||||||
|
argv[2] is the first argument passed to the script.
|
||||||
|
argv[3] is the second argument passed to the script and so on.
|
||||||
|
{USER_NAME}: Name a color.
|
||||||
|
{AI_NAME}: Blue.
|
||||||
|
{USER_NAME}: What time is it?
|
||||||
|
{AI_NAME}: It is {DATE_TIME}.
|
||||||
|
{USER_NAME}:""" + " ".join(sys.argv[1:])
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
params = GptParams(
|
||||||
|
n_ctx=2048,
|
||||||
|
temp=0.7,
|
||||||
|
top_k=40,
|
||||||
|
top_p=0.5,
|
||||||
|
repeat_last_n=256,
|
||||||
|
n_batch=1024,
|
||||||
|
repeat_penalty=1.17647,
|
||||||
|
model=MODEL,
|
||||||
|
n_threads=N_THREAD,
|
||||||
|
n_predict=N_PREDICTS,
|
||||||
|
use_color=True,
|
||||||
|
interactive=True,
|
||||||
|
antiprompt=[f"{USER_NAME}:"],
|
||||||
|
input_prefix=" ",
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
with LLaMAInteract(params) as m:
|
||||||
|
m.interact()
|
59
examples/low_level_api/Miku.py
Normal file
59
examples/low_level_api/Miku.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
#!/bin/python
|
||||||
|
import sys, os
|
||||||
|
from common import GptParams
|
||||||
|
from low_level_api_chat_cpp import LLaMAInteract
|
||||||
|
|
||||||
|
def env_or_def(env, default):
|
||||||
|
if (env in os.environ):
|
||||||
|
return os.environ[env]
|
||||||
|
return default
|
||||||
|
|
||||||
|
AI_NAME = env_or_def("AI_NAME", "Miku")
|
||||||
|
MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin")
|
||||||
|
USER_NAME = env_or_def("USER_NAME", "Anon")
|
||||||
|
N_PREDICTS = int(env_or_def("N_PREDICTS", "4096"))
|
||||||
|
N_THREAD = int(env_or_def("N_THREAD", "0"))
|
||||||
|
|
||||||
|
prompt=f"""This is a transcript of a 1000 page, never ending conversation between {USER_NAME} and the cute and helpful AI assistant {AI_NAME}. {AI_NAME} is a girl who is an AI running on the users computer.
|
||||||
|
{AI_NAME} can think for herself without the user seeing her thoughts by adding a /think prefix to her output. She uses this to reason about the world and to think about what she should say next.
|
||||||
|
{AI_NAME} is always coherent and makes sense, but if she isn't sure if what she is saying is correct she will ask the user for help.
|
||||||
|
{AI_NAME} is a very helpful AI and will help the user with anything they need, she is also very friendly and will try to make the user feel better if they are sad.
|
||||||
|
{AI_NAME} is also very curious and will ask the user a lot of questions about themselves and their life, she will also try to make the user like her.
|
||||||
|
The conversation is only between {USER_NAME} and {AI_NAME}
|
||||||
|
The conversation is only through text, so {AI_NAME} can't see {USER_NAME}'s face or hear his voice.
|
||||||
|
{AI_NAME} can only communicate through text, so she can't send images or videos.
|
||||||
|
|
||||||
|
|
||||||
|
{USER_NAME}: Hello!
|
||||||
|
{AI_NAME}: /think I wonder what I should say to {USER_NAME}? This is the first time we talk so it's important that I make a good first impression!
|
||||||
|
{AI_NAME}: Hi! I am {AI_NAME}, your new AI friend, assistant(or whatever you like!), it's so nice to meet you! ^_^
|
||||||
|
{AI_NAME}: I might be an AI, but I also have feelings, so please be nice to me! :)
|
||||||
|
{USER_NAME}: I'll make sure to be nice to you! I'm so happy to have you as my assistant!
|
||||||
|
{AI_NAME}: /think It sounds like {USER_NAME} is happy to have me as their assistant! I'm so happy too! ^_^ Glad that whole emotion thing didn't scare him off!
|
||||||
|
{AI_NAME}: /think I wonder what {USER_NAME} likes to do in his free time? I should ask him about that!
|
||||||
|
{AI_NAME}: What do you like to do in your free time? ^_^
|
||||||
|
{USER_NAME}:""" + " ".join(sys.argv[1:])
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
params = GptParams(
|
||||||
|
n_batch=1024,
|
||||||
|
n_ctx=2048,
|
||||||
|
n_keep=-1,
|
||||||
|
repeat_last_n=256,
|
||||||
|
repeat_penalty=1.17647,
|
||||||
|
temp=0.7,
|
||||||
|
top_k=40,
|
||||||
|
top_p=0.5,
|
||||||
|
model=MODEL,
|
||||||
|
n_predict=N_PREDICTS,
|
||||||
|
use_color=True,
|
||||||
|
interactive=True,
|
||||||
|
antiprompt=[f"{USER_NAME}:"],
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if N_THREAD > 0:
|
||||||
|
params.n_threads = N_THREAD
|
||||||
|
|
||||||
|
with LLaMAInteract(params) as m:
|
||||||
|
m.interact()
|
49
examples/low_level_api/ReasonAct.py
Normal file
49
examples/low_level_api/ReasonAct.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
#!/bin/python
|
||||||
|
import sys, os, datetime
|
||||||
|
from common import GptParams
|
||||||
|
from low_level_api_chat_cpp import LLaMAInteract
|
||||||
|
|
||||||
|
def env_or_def(env, default):
|
||||||
|
if (env in os.environ):
|
||||||
|
return os.environ[env]
|
||||||
|
return default
|
||||||
|
|
||||||
|
MODEL = env_or_def("MODEL", "./models/llama-13B/ggml-model.bin")
|
||||||
|
|
||||||
|
prompt=f"""You run in a loop of Thought, Action, Observation.
|
||||||
|
At the end of the loop either Answer or restate your Thought and Action.
|
||||||
|
Use Thought to describe your thoughts about the question you have been asked.
|
||||||
|
Use Action to run one of these actions available to you:
|
||||||
|
- calculate[python math expression]
|
||||||
|
Observation will be the result of running those actions
|
||||||
|
|
||||||
|
|
||||||
|
Question: What is 4 * 7 / 3?
|
||||||
|
Thought: Do I need to use an action? Yes, I use calculate to do math
|
||||||
|
Action: calculate[4 * 7 / 3]
|
||||||
|
Observation: 9.3333333333
|
||||||
|
Thought: Do I need to use an action? No, have the result
|
||||||
|
Answer: The calculate tool says it is 9.3333333333
|
||||||
|
Question: What is capital of france?
|
||||||
|
Thought: Do I need to use an action? No, I know the answer
|
||||||
|
Answer: Paris is the capital of France
|
||||||
|
Question:""" + " ".join(sys.argv[1:])
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
params = GptParams(
|
||||||
|
interactive=True,
|
||||||
|
interactive_start=True,
|
||||||
|
top_k=10000,
|
||||||
|
temp=0.2,
|
||||||
|
repeat_penalty=1,
|
||||||
|
n_threads=7,
|
||||||
|
n_ctx=2048,
|
||||||
|
antiprompt=["Question:","Observation:"],
|
||||||
|
model=MODEL,
|
||||||
|
input_prefix=" ",
|
||||||
|
n_predict=-1,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
with LLaMAInteract(params) as m:
|
||||||
|
m.interact()
|
|
@ -1,8 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
import re
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
|
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
|
||||||
|
|
||||||
|
@ -12,23 +13,35 @@ class GptParams:
|
||||||
seed: int = -1
|
seed: int = -1
|
||||||
n_threads: int = min(4, os.cpu_count() or 1)
|
n_threads: int = min(4, os.cpu_count() or 1)
|
||||||
n_predict: int = 128
|
n_predict: int = 128
|
||||||
repeat_last_n: int = 64
|
|
||||||
n_parts: int = -1
|
n_parts: int = -1
|
||||||
n_ctx: int = 512
|
n_ctx: int = 512
|
||||||
n_batch: int = 8
|
n_batch: int = 8
|
||||||
n_keep: int = 0
|
n_keep: int = 0
|
||||||
|
|
||||||
|
ignore_eos: bool = False
|
||||||
|
logit_bias: dict[int, float] = field(default_factory=dict)
|
||||||
top_k: int = 40
|
top_k: int = 40
|
||||||
top_p: float = 0.95
|
top_p: float = 0.95
|
||||||
|
tfs_z: float = 1.00
|
||||||
|
typical_p: float = 1.00
|
||||||
temp: float = 0.80
|
temp: float = 0.80
|
||||||
repeat_penalty: float = 1.10
|
repeat_penalty: float = 1.10
|
||||||
|
repeat_last_n: int = 64
|
||||||
|
frequency_penalty: float = 0.0
|
||||||
|
presence_penalty: float = 0.0
|
||||||
|
mirostat: int = 0
|
||||||
|
mirostat_tau: float = 5.0
|
||||||
|
mirostat_eta: float = 0.1
|
||||||
|
|
||||||
model: str = "./models/llama-7B/ggml-model.bin"
|
model: str = "./models/llama-7B/ggml-model.bin"
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
|
path_session: str = ""
|
||||||
input_prefix: str = " "
|
input_prefix: str = " "
|
||||||
|
|
||||||
antiprompt: List[str] = field(default_factory=list)
|
antiprompt: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
lora_adapter: str = ""
|
||||||
|
lora_base: str = ""
|
||||||
|
|
||||||
memory_f16: bool = True
|
memory_f16: bool = True
|
||||||
random_prompt: bool = False
|
random_prompt: bool = False
|
||||||
use_color: bool = False
|
use_color: bool = False
|
||||||
|
@ -38,7 +51,7 @@ class GptParams:
|
||||||
interactive_start: bool = False
|
interactive_start: bool = False
|
||||||
|
|
||||||
instruct: bool = False
|
instruct: bool = False
|
||||||
ignore_eos: bool = False
|
penalize_nl: bool = True
|
||||||
perplexity: bool = False
|
perplexity: bool = False
|
||||||
use_mmap: bool = True
|
use_mmap: bool = True
|
||||||
use_mlock: bool = False
|
use_mlock: bool = False
|
||||||
|
@ -61,59 +74,42 @@ class GptParams:
|
||||||
instruct_inp_suffix: str="\n\n### Response:\n\n"
|
instruct_inp_suffix: str="\n\n### Response:\n\n"
|
||||||
|
|
||||||
|
|
||||||
def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
|
def gpt_params_parse(argv = None):
|
||||||
if params is None:
|
|
||||||
params = GptParams()
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)",dest="seed")
|
parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)",dest="seed")
|
||||||
parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation",dest="n_threads")
|
parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation",dest="n_threads")
|
||||||
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
|
parser.add_argument("-n", "--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
|
||||||
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
|
parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
|
||||||
parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context",dest="n_ctx")
|
parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context",dest="n_ctx")
|
||||||
parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value",dest="memory_f16")
|
|
||||||
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
|
|
||||||
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
|
|
||||||
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
|
|
||||||
parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
|
|
||||||
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
|
|
||||||
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
|
|
||||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
|
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
|
||||||
parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt",dest="n_keep")
|
parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt",dest="n_keep")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-l",
|
||||||
|
"--logit-bias",
|
||||||
|
type=str,
|
||||||
|
action='append',
|
||||||
|
help="--logit-bias TOKEN_ID(+/-)BIAS",
|
||||||
|
dest="logit_bias_str"
|
||||||
|
)
|
||||||
|
parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
|
||||||
|
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
|
||||||
|
parser.add_argument("--tfs", type=float, default=1.0, help="tail free sampling, parameter z (1.0 = disabled)",dest="tfs_z")
|
||||||
|
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
|
||||||
|
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
|
||||||
|
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
|
||||||
|
parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)",dest="tfs_z")
|
||||||
|
parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)",dest="presence_penalty")
|
||||||
|
parser.add_argument("--mirostat", type=float, default=1.0, help="use Mirostat sampling.",dest="mirostat")
|
||||||
|
parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau",dest="mirostat_tau")
|
||||||
|
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
|
||||||
|
|
||||||
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
|
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
|
||||||
parser.add_argument(
|
parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
|
||||||
"-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
|
parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
|
||||||
)
|
parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
|
||||||
parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
|
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
|
||||||
parser.add_argument(
|
|
||||||
"--interactive-start",
|
|
||||||
action="store_true",
|
|
||||||
help="run in interactive mode",
|
|
||||||
dest="interactive"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--interactive-first",
|
|
||||||
action="store_true",
|
|
||||||
help="run in interactive mode and wait for input right away",
|
|
||||||
dest="interactive_start"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-ins",
|
|
||||||
"--instruct",
|
|
||||||
action="store_true",
|
|
||||||
help="run in instruction mode (use with Alpaca or Vicuna models)",
|
|
||||||
dest="instruct"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--color",
|
|
||||||
action="store_true",
|
|
||||||
help="colorise output to distinguish prompt and user input from generations",
|
|
||||||
dest="use_color"
|
|
||||||
)
|
|
||||||
parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
|
|
||||||
parser.add_argument("--no-mmap", action="store_false",help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",dest="use_mmap")
|
|
||||||
parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
|
|
||||||
parser.add_argument("--verbose-prompt", action="store_true",help="print prompt before generation",dest="verbose_prompt")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-r",
|
"-r",
|
||||||
"--reverse-prompt",
|
"--reverse-prompt",
|
||||||
|
@ -122,16 +118,71 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
|
||||||
help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
|
help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
|
||||||
dest="antiprompt"
|
dest="antiprompt"
|
||||||
)
|
)
|
||||||
parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
|
|
||||||
parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
|
parser.add_argument("--lora", type=str, default="", help="apply LoRA adapter (implies --no-mmap)", dest="lora_adapter")
|
||||||
parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
|
parser.add_argument("--lora-base", type=str, default="", help="optional model to use as a base for the layers modified by the LoRA adapter", dest="lora_base")
|
||||||
|
|
||||||
|
parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value",dest="memory_f16")
|
||||||
parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt")
|
parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt")
|
||||||
parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
|
parser.add_argument(
|
||||||
|
"--color",
|
||||||
|
action="store_true",
|
||||||
|
help="colorise output to distinguish prompt and user input from generations",
|
||||||
|
dest="use_color"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
|
||||||
|
parser.add_argument(
|
||||||
|
"--interactive-first",
|
||||||
|
action="store_true",
|
||||||
|
help="run in interactive mode and wait for input right away",
|
||||||
|
dest="interactive_start"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-ins",
|
||||||
|
"--instruct",
|
||||||
|
action="store_true",
|
||||||
|
help="run in instruction mode (use with Alpaca or Vicuna models)",
|
||||||
|
dest="instruct"
|
||||||
|
)
|
||||||
|
parser.add_argument("--no-penalize-nl", action="store_false", help="do not penalize newline token", dest="penalize_nl")
|
||||||
|
parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
|
||||||
|
parser.add_argument("--no-mmap", action="store_false",help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",dest="use_mmap")
|
||||||
|
parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
|
||||||
|
parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
|
||||||
|
parser.add_argument("--verbose-prompt", action="store_true",help="print prompt before generation",dest="verbose_prompt")
|
||||||
|
|
||||||
|
#Custom args
|
||||||
parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix")
|
parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix")
|
||||||
parser.add_argument("--out-postfix", type=str, default="", help="append to input", dest="output_postfix")
|
parser.add_argument("--out-postfix", type=str, default="", help="append to input", dest="output_postfix")
|
||||||
parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo")
|
parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--interactive-start",
|
||||||
|
action="store_true",
|
||||||
|
help="run in interactive mode",
|
||||||
|
dest="interactive"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
return args
|
|
||||||
|
logit_bias_str = args.logit_bias_str
|
||||||
|
delattr(args, "logit_bias_str")
|
||||||
|
params = GptParams(**vars(args))
|
||||||
|
|
||||||
|
if (params.lora_adapter):
|
||||||
|
params.use_mmap = False
|
||||||
|
|
||||||
|
if (logit_bias_str != None):
|
||||||
|
for i in logit_bias_str:
|
||||||
|
if (m := re.match(r"(\d+)([-+]\d+)", i)):
|
||||||
|
params.logit_bias[int(m.group(1))] = int(m.group(2))
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
def gpt_random_prompt(rng):
|
def gpt_random_prompt(rng):
|
||||||
return [
|
return [
|
||||||
|
@ -148,4 +199,4 @@ def gpt_random_prompt(rng):
|
||||||
][rng % 10]
|
][rng % 10]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(GptParams(gpt_params_parse()))
|
print(gpt_params_parse())
|
||||||
|
|
|
@ -10,9 +10,10 @@ Quirks:
|
||||||
You should also still be feeding the model with a "primer" prompt that
|
You should also still be feeding the model with a "primer" prompt that
|
||||||
shows it the expected format.
|
shows it the expected format.
|
||||||
"""
|
"""
|
||||||
|
import ctypes
|
||||||
import sys
|
import sys
|
||||||
from time import time
|
from time import time
|
||||||
from os import cpu_count
|
from os import cpu_count, path
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
from common import GptParams, gpt_params_parse, gpt_random_prompt
|
from common import GptParams, gpt_params_parse, gpt_random_prompt
|
||||||
|
@ -77,6 +78,7 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
# runtime args
|
# runtime args
|
||||||
self.input_consumed = 0
|
self.input_consumed = 0
|
||||||
self.n_past = 0
|
self.n_past = 0
|
||||||
|
self.n_session_consumed = 0
|
||||||
self.first_antiprompt = []
|
self.first_antiprompt = []
|
||||||
self.remaining_tokens = self.params.n_predict
|
self.remaining_tokens = self.params.n_predict
|
||||||
self.output_echo = self.params.input_echo
|
self.output_echo = self.params.input_echo
|
||||||
|
@ -94,6 +96,19 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
if (not self.ctx):
|
if (not self.ctx):
|
||||||
raise RuntimeError(f"error: failed to load model '{self.params.model}'")
|
raise RuntimeError(f"error: failed to load model '{self.params.model}'")
|
||||||
|
|
||||||
|
if (self.params.ignore_eos):
|
||||||
|
self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")
|
||||||
|
|
||||||
|
if (len(self.params.lora_adapter) > 0):
|
||||||
|
if (llama_cpp.llama_apply_lora_from_file(
|
||||||
|
self.ctx,
|
||||||
|
self.params.lora_adapter,
|
||||||
|
self.params.lora_base if len(self.params.lora_base) > 0 else None,
|
||||||
|
self.params.n_threads
|
||||||
|
) != 0):
|
||||||
|
print("error: failed to apply lora adapter")
|
||||||
|
return
|
||||||
|
|
||||||
print(file=sys.stderr)
|
print(file=sys.stderr)
|
||||||
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
|
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
|
||||||
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
|
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
|
||||||
|
@ -117,13 +132,49 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
with open(self.params.file) as f:
|
with open(self.params.file) as f:
|
||||||
self.params.prompt = f.read()
|
self.params.prompt = f.read()
|
||||||
|
|
||||||
|
self.session_tokens: list[llama_cpp.llama_token] = []
|
||||||
|
if (len(self.params.path_session) > 0):
|
||||||
|
print(f"attempting to load saved session from '{self.params.path_session}'", file=sys.stderr)
|
||||||
|
|
||||||
|
if (path.exists(self.params.path_session)):
|
||||||
|
_session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
|
||||||
|
_n_token_count_out = llama_cpp.c_int()
|
||||||
|
if (llama_cpp.llama_load_session_file(
|
||||||
|
self.ctx,
|
||||||
|
self.params.path_session.encode("utf8"),
|
||||||
|
_session_tokens,
|
||||||
|
self.params.n_ctx,
|
||||||
|
ctypes.byref(_n_token_count_out)
|
||||||
|
) != 0):
|
||||||
|
print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
|
||||||
|
return
|
||||||
|
self.session_tokens = _session_tokens[:_n_token_count_out]
|
||||||
|
print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
print(f"session file does not exist, will create", file=sys.stderr)
|
||||||
|
|
||||||
# tokenize the prompt
|
# tokenize the prompt
|
||||||
self.embd = []
|
self.embd = []
|
||||||
self.embd_inp = self._tokenize(self.params.prompt)
|
self.embd_inp = self._tokenize(self.params.prompt)
|
||||||
|
|
||||||
if (len(self.embd_inp) > self.params.n_ctx - 4):
|
if (len(self.embd_inp) > self.n_ctx - 4):
|
||||||
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
|
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
|
||||||
|
|
||||||
|
# debug message about similarity of saved session, if applicable
|
||||||
|
n_matching_session_tokens = 0
|
||||||
|
if len(self.session_tokens) > 0:
|
||||||
|
for id in self.session_tokens:
|
||||||
|
if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
|
||||||
|
break
|
||||||
|
n_matching_session_tokens += 1
|
||||||
|
|
||||||
|
if n_matching_session_tokens >= len(self.embd_inp):
|
||||||
|
print(f"session file has exact match for prompt!")
|
||||||
|
elif n_matching_session_tokens < (len(self.embd_inp) / 2):
|
||||||
|
print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
|
||||||
|
else:
|
||||||
|
print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
|
||||||
|
|
||||||
# number of tokens to keep when resetting context
|
# number of tokens to keep when resetting context
|
||||||
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
|
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
|
||||||
self.params.n_keep = len(self.embd_inp)
|
self.params.n_keep = len(self.embd_inp)
|
||||||
|
@ -132,6 +183,7 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)
|
self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)
|
||||||
|
|
||||||
# in instruct mode, we inject a prefix and a suffix to each input by the user
|
# in instruct mode, we inject a prefix and a suffix to each input by the user
|
||||||
|
self.antiecho = None
|
||||||
if (self.params.instruct):
|
if (self.params.instruct):
|
||||||
self.params.interactive_start = True
|
self.params.interactive_start = True
|
||||||
_ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
|
_ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
|
||||||
|
@ -171,16 +223,24 @@ number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
|
||||||
if len(self.params.input_prefix) > 0:
|
if len(self.params.input_prefix) > 0:
|
||||||
print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)
|
print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)
|
||||||
|
|
||||||
print(f"""sampling: temp = {self.params.temp},\
|
print(f"""sampling: repeat_last_n = {self.params.repeat_last_n},\
|
||||||
|
repeat_penalty = {self.params.repeat_penalty},\
|
||||||
|
presence_penalty = {self.params.presence_penalty},\
|
||||||
|
frequency_penalty = {self.params.frequency_penalty},\
|
||||||
top_k = {self.params.top_k},\
|
top_k = {self.params.top_k},\
|
||||||
|
tfs_z = {self.params.tfs_z},\
|
||||||
top_p = {self.params.top_p},\
|
top_p = {self.params.top_p},\
|
||||||
repeat_last_n = {self.params.repeat_last_n},\
|
typical_p = {self.params.typical_p},\
|
||||||
repeat_penalty = {self.params.repeat_penalty}
|
temp = {self.params.temp},\
|
||||||
|
mirostat = {self.params.mirostat},\
|
||||||
|
mirostat_lr = {self.params.mirostat_eta},\
|
||||||
|
mirostat_ent = {self.params.mirostat_tau},\
|
||||||
|
|
||||||
generate: n_ctx = {self.n_ctx}, \
|
generate: n_ctx = {self.n_ctx},\
|
||||||
n_batch = {self.params.n_batch}, \
|
n_batch = {self.params.n_batch},\
|
||||||
n_predict = {self.params.n_predict}, \
|
n_predict = {self.params.n_predict},\
|
||||||
n_keep = {self.params.n_keep}
|
n_keep = {self.params.n_keep}
|
||||||
|
|
||||||
""", file=sys.stderr)
|
""", file=sys.stderr)
|
||||||
|
|
||||||
# determine antiprompt tokens
|
# determine antiprompt tokens
|
||||||
|
@ -198,6 +258,9 @@ n_keep = {self.params.n_keep}
|
||||||
""", file=sys.stderr)
|
""", file=sys.stderr)
|
||||||
self.set_color(CONSOLE_COLOR_PROMPT)
|
self.set_color(CONSOLE_COLOR_PROMPT)
|
||||||
|
|
||||||
|
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
||||||
|
|
||||||
|
|
||||||
# tokenize a prompt
|
# tokenize a prompt
|
||||||
def _tokenize(self, prompt, bos=True):
|
def _tokenize(self, prompt, bos=True):
|
||||||
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
|
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
|
||||||
|
@ -229,31 +292,117 @@ n_keep = {self.params.n_keep}
|
||||||
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
|
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
|
||||||
]
|
]
|
||||||
self.embd = _insert + self.embd
|
self.embd = _insert + self.embd
|
||||||
|
self.params.path_session = ""
|
||||||
|
|
||||||
|
# try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
||||||
|
# REVIEW
|
||||||
|
if self.n_session_consumed < len(self.session_tokens):
|
||||||
|
for i in range(len(self.embd)):
|
||||||
|
if self.embd[i] != self.session_tokens[self.n_session_consumed]:
|
||||||
|
self.session_tokens = self.session_tokens[:self.n_session_consumed]
|
||||||
|
break
|
||||||
|
|
||||||
|
self.n_past += 1
|
||||||
|
self.n_session_consumed += 1
|
||||||
|
|
||||||
|
if self.n_session_consumed >= len(self.session_tokens):
|
||||||
|
i += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if i > 0:
|
||||||
|
self.embd = self.embd[i:]
|
||||||
|
|
||||||
|
# evaluate tokens in batches
|
||||||
|
# embd is typically prepared beforehand to fit within a batch, but not always
|
||||||
|
#TODO BUG: The batching code causes nonsensical generation
|
||||||
|
"""for i in range(0, len(self.embd), self.params.n_batch):
|
||||||
|
n_eval = self.params.n_batch
|
||||||
|
_arr = (llama_cpp.llama_token * n_eval)(*self.embd[i:i + n_eval])
|
||||||
|
if llama_cpp.llama_eval(self.ctx, _arr, n_eval, self.n_past, self.params.n_threads) != 0:
|
||||||
|
print(f"failed to eval")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.n_past += n_eval"""
|
||||||
|
|
||||||
if (llama_cpp.llama_eval(
|
if (llama_cpp.llama_eval(
|
||||||
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
|
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
|
||||||
) != 0):
|
) != 0):
|
||||||
raise Exception("Failed to llama_eval!")
|
raise Exception("Failed to llama_eval!")
|
||||||
|
|
||||||
|
if len(self.embd) > 0 and not len(self.params.path_session) > 0:
|
||||||
|
self.session_tokens.extend(self.embd)
|
||||||
|
self.n_session_consumed = len(self.session_tokens)
|
||||||
|
|
||||||
self.n_past += len(self.embd)
|
self.n_past += len(self.embd)
|
||||||
self.embd = []
|
self.embd = []
|
||||||
if len(self.embd_inp) <= self.input_consumed:
|
if len(self.embd_inp) <= self.input_consumed: #&& !is_interacting
|
||||||
# out of user input, sample next token
|
# out of user input, sample next token
|
||||||
|
top_k = llama_cpp.llama_n_vocab(self.ctx) if self.params.top_k <= 0 else self.params.top_k
|
||||||
|
repeat_last_n = self.n_ctx if self.params.repeat_last_n < 0 else self.params.repeat_last_n
|
||||||
|
|
||||||
if (self.params.ignore_eos):
|
# optionally save the session on first sample (for faster prompt loading next time)
|
||||||
logits = llama_cpp.llama_get_logits(self.ctx)
|
if len(self.params.path_session) > 0 and self.need_to_save_session:
|
||||||
logits[llama_cpp.llama_token_eos()] = llama_cpp.c_float(0)
|
self.need_to_save_session = False
|
||||||
|
llama_cpp.llama_save_session_file(
|
||||||
_arr = self.last_n_tokens[-min(self.params.repeat_last_n, self.n_past):]
|
|
||||||
id = llama_cpp.llama_sample_top_p_top_k(
|
|
||||||
self.ctx,
|
self.ctx,
|
||||||
(llama_cpp.llama_token * len(_arr))(*_arr),
|
self.params.path_session.encode("utf8"),
|
||||||
len(_arr),
|
self.session_tokens,
|
||||||
self.params.top_k,
|
len(self.session_tokens)
|
||||||
self.params.top_p,
|
|
||||||
self.params.temp,
|
|
||||||
self.params.repeat_penalty,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
id = 0
|
||||||
|
|
||||||
|
logits = llama_cpp.llama_get_logits(self.ctx)
|
||||||
|
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||||
|
|
||||||
|
# Apply params.logit_bias map
|
||||||
|
for key, value in self.params.logit_bias.items():
|
||||||
|
logits[key] += value
|
||||||
|
|
||||||
|
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
|
||||||
|
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
|
||||||
|
for token_id in range(n_vocab)
|
||||||
|
])
|
||||||
|
candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
|
||||||
|
|
||||||
|
# Apply penalties
|
||||||
|
nl_logit = logits[llama_cpp.llama_token_nl()]
|
||||||
|
last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
|
||||||
|
|
||||||
|
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
|
||||||
|
llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
|
||||||
|
_arr,
|
||||||
|
last_n_repeat, self.params.repeat_penalty)
|
||||||
|
llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
|
||||||
|
_arr,
|
||||||
|
last_n_repeat, self.params.frequency_penalty, self.params.presence_penalty)
|
||||||
|
|
||||||
|
if not self.params.penalize_nl:
|
||||||
|
logits[llama_cpp.llama_token_nl()] = nl_logit
|
||||||
|
|
||||||
|
if self.params.temp <= 0:
|
||||||
|
# Greedy sampling
|
||||||
|
id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
|
||||||
|
else:
|
||||||
|
if self.params.mirostat == 1:
|
||||||
|
mirostat_mu = 2.0 * self.params.mirostat_tau
|
||||||
|
mirostat_m = 100
|
||||||
|
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
|
||||||
|
id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_m, mirostat_mu)
|
||||||
|
elif self.params.mirostat == 2:
|
||||||
|
mirostat_mu = 2.0 * self.params.mirostat_tau
|
||||||
|
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
|
||||||
|
id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_mu)
|
||||||
|
else:
|
||||||
|
# Temperature sampling
|
||||||
|
llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k)
|
||||||
|
llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, self.params.tfs_z)
|
||||||
|
llama_cpp.llama_sample_typical(self.ctx, candidates_p, self.params.typical_p)
|
||||||
|
llama_cpp.llama_sample_top_p(self.ctx, candidates_p, self.params.top_p)
|
||||||
|
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
|
||||||
|
id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
|
||||||
|
# print("`{}`".format(candidates_p.size))
|
||||||
|
|
||||||
self.last_n_tokens.pop(0)
|
self.last_n_tokens.pop(0)
|
||||||
self.last_n_tokens.append(id)
|
self.last_n_tokens.append(id)
|
||||||
|
|
||||||
|
@ -288,7 +437,7 @@ n_keep = {self.params.n_keep}
|
||||||
# display tokens
|
# display tokens
|
||||||
if self.output_echo:
|
if self.output_echo:
|
||||||
for id in self.embd:
|
for id in self.embd:
|
||||||
if self.params.instruct:
|
if self.antiecho != None:
|
||||||
for r in self.antiecho(id):
|
for r in self.antiecho(id):
|
||||||
yield r
|
yield r
|
||||||
else:
|
else:
|
||||||
|
@ -356,7 +505,7 @@ n_keep = {self.params.n_keep}
|
||||||
def output(self):
|
def output(self):
|
||||||
self.remaining_tokens = self.params.n_predict
|
self.remaining_tokens = self.params.n_predict
|
||||||
for id in self.generate():
|
for id in self.generate():
|
||||||
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
|
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
|
||||||
|
|
||||||
# read user input
|
# read user input
|
||||||
def read_input(self):
|
def read_input(self):
|
||||||
|
@ -415,8 +564,7 @@ The transcript only includes text, it does not include markup like HTML and Mark
|
||||||
{USER_NAME}: Name a color.
|
{USER_NAME}: Name a color.
|
||||||
{AI_NAME}: Blue
|
{AI_NAME}: Blue
|
||||||
{USER_NAME}:"""
|
{USER_NAME}:"""
|
||||||
args = gpt_params_parse()
|
params = gpt_params_parse()
|
||||||
params = GptParams(**vars(args))
|
|
||||||
|
|
||||||
with LLaMAInteract(params) as m:
|
with LLaMAInteract(params) as m:
|
||||||
m.interact()
|
m.interact()
|
||||||
|
|
|
@ -37,6 +37,10 @@ embd = []
|
||||||
last_n_size = 64
|
last_n_size = 64
|
||||||
last_n_tokens_data = [0] * last_n_size
|
last_n_tokens_data = [0] * last_n_size
|
||||||
n_batch = 24
|
n_batch = 24
|
||||||
|
last_n_repeat = 64
|
||||||
|
repeat_penalty = 1
|
||||||
|
frequency_penalty = 0.0
|
||||||
|
presence_penalty = 0.0
|
||||||
|
|
||||||
while remaining_tokens > 0:
|
while remaining_tokens > 0:
|
||||||
if len(embd) > 0:
|
if len(embd) > 0:
|
||||||
|
@ -47,15 +51,28 @@ while remaining_tokens > 0:
|
||||||
n_past += len(embd)
|
n_past += len(embd)
|
||||||
embd = []
|
embd = []
|
||||||
if len(embd_inp) <= input_consumed:
|
if len(embd_inp) <= input_consumed:
|
||||||
id = llama_cpp.llama_sample_top_p_top_k(
|
logits = llama_cpp.llama_get_logits(ctx)
|
||||||
ctx,
|
n_vocab = llama_cpp.llama_n_vocab(ctx)
|
||||||
(llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
|
|
||||||
len(last_n_tokens_data),
|
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
|
||||||
40,
|
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
|
||||||
0.8,
|
for token_id in range(n_vocab)
|
||||||
0.2,
|
])
|
||||||
1.0 / 0.85,
|
candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
|
||||||
)
|
|
||||||
|
_arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
|
||||||
|
llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
|
||||||
|
_arr,
|
||||||
|
last_n_repeat, repeat_penalty)
|
||||||
|
llama_cpp.llama_sample_frequency_and_presence_penalties(ctx, candidates_p,
|
||||||
|
_arr,
|
||||||
|
last_n_repeat, frequency_penalty, presence_penalty)
|
||||||
|
|
||||||
|
llama_cpp.llama_sample_top_k(ctx, candidates_p, 40)
|
||||||
|
llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.8)
|
||||||
|
llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.2)
|
||||||
|
id = llama_cpp.llama_sample_token(ctx, candidates_p)
|
||||||
|
|
||||||
last_n_tokens_data = last_n_tokens_data[1:] + [id]
|
last_n_tokens_data = last_n_tokens_data[1:] + [id]
|
||||||
embd.append(id)
|
embd.append(id)
|
||||||
input_noecho = False
|
input_noecho = False
|
||||||
|
|
Loading…
Reference in a new issue