Black formatting

Andrei Betlen 2023-03-24 14:35:41 -04:00
parent d29b05bb67
commit 2cc499512c
6 changed files with 121 additions and 35 deletions

View file

@@ -5,9 +5,11 @@ from llama_cpp import Llama
 from fastapi import FastAPI
 from pydantic import BaseModel, BaseSettings, Field

 class Settings(BaseSettings):
     model: str

 app = FastAPI(
     title="🦙 llama.cpp Python API",
     version="0.0.1",
@@ -15,6 +17,7 @@ app = FastAPI(
 settings = Settings()
 llama = Llama(settings.model)

 class CompletionRequest(BaseModel):
     prompt: str
     suffix: Optional[str] = Field(None)
@@ -31,12 +34,11 @@ class CompletionRequest(BaseModel):
         schema_extra = {
             "example": {
                 "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"]
+                "stop": ["\n", "###"],
             }
         }

 @app.post("/v1/completions")
 def completions(request: CompletionRequest):
     return llama(**request.dict())
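For reference, a minimal client-side sketch of calling the /v1/completions route defined above. The served module name, host, and port are assumptions here (e.g. the app run with uvicorn on localhost:8000), not part of this commit; the request fields mirror the CompletionRequest example.

import requests

# Assumes the example app above is running, e.g. `uvicorn fastapi_server:app`
# on localhost:8000 (both the module name and the port are hypothetical).
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
        "stop": ["\n", "###"],
    },
)
print(response.json())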

View file

@@ -9,6 +9,11 @@ args = parser.parse_args()
 llm = Llama(model_path=args.model)

-output = llm("Question: What are the names of the planets in the solar system? Answer: ", max_tokens=48, stop=["Q:", "\n"], echo=True)
+output = llm(
+    "Question: What are the names of the planets in the solar system? Answer: ",
+    max_tokens=48,
+    stop=["Q:", "\n"],
+    echo=True,
+)

 print(json.dumps(output, indent=2))
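The output printed above is a completion dict; a minimal sketch of pulling the generated text back out, assuming the OpenAI-style layout suggested by the "text_completion" response built in llama_cpp/llama.py later in this commit (the "choices"/"text" keys are an assumption here, not shown in this diff):

# Hypothetical post-processing of `output`; assumes an OpenAI-style
# {"choices": [{"text": ...}]} layout for the completion dict.
text = output["choices"][0]["text"]
print(text.strip())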

View file

@@ -5,6 +5,7 @@ from llama_cpp import Llama
 from langchain.llms.base import LLM
 from typing import Optional, List, Mapping, Any

 class LlamaLLM(LLM):
     model_path: str
     llm: Llama
@@ -16,7 +17,7 @@ class LlamaLLM(LLM):
     def __init__(self, model_path: str, **kwargs: Any):
         model_path = model_path
         llm = Llama(model_path=model_path)
         super().__init__(model_path=model_path, llm=llm, **kwargs)

     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         response = self.llm(prompt, stop=stop or [])
@@ -26,6 +27,7 @@ class LlamaLLM(LLM):
     def _identifying_params(self) -> Mapping[str, Any]:
         return {"model_path": self.model_path}

 parser = argparse.ArgumentParser()
 parser.add_argument("-m", "--model", type=str, default="./models/...")
 args = parser.parse_args()
@@ -34,7 +36,9 @@ args = parser.parse_args()
 llm = LlamaLLM(model_path=args.model)

 # Basic Q&A
-answer = llm("Question: What is the capital of France? Answer: ", stop=["Question:", "\n"])
+answer = llm(
+    "Question: What is the capital of France? Answer: ", stop=["Question:", "\n"]
+)
 print(f"Answer: {answer.strip()}")

 # Using in a chain
@@ -48,4 +52,4 @@ prompt = PromptTemplate(
 chain = LLMChain(llm=llm, prompt=prompt)

 # Run the chain only specifying the input variable.
 print(chain.run("colorful socks"))

View file

@@ -27,7 +27,15 @@ embd = embd_inp
 n = 8
 for i in range(n):
-    id = llama_cpp.llama_sample_top_p_top_k(ctx, (llama_cpp.c_int * len(embd))(*embd), n_of_tok + i, 40, 0.8, 0.2, 1.0/0.85)
+    id = llama_cpp.llama_sample_top_p_top_k(
+        ctx,
+        (llama_cpp.c_int * len(embd))(*embd),
+        n_of_tok + i,
+        40,
+        0.8,
+        0.2,
+        1.0 / 0.85,
+    )
     embd.append(id)
@@ -38,4 +46,4 @@ for i in range(n):
 llama_cpp.llama_free(ctx)

 print(prediction.decode("utf-8"))
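The `(llama_cpp.c_int * len(embd))(*embd)` argument in the sampling call above is the standard ctypes idiom for handing a Python list to a C function that expects an int pointer: build an array type of the right length, then instantiate it from the list. A standalone sketch (no model required):

from ctypes import c_int

embd = [1, 5, 42]             # a Python list of token ids
arr_type = c_int * len(embd)  # array type holding exactly len(embd) c_ints
arr = arr_type(*embd)         # instance that can be passed as an int* argument
print(list(arr))              # [1, 5, 42]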

View file

@@ -5,6 +5,7 @@ from typing import List, Optional
 from . import llama_cpp

 class Llama:
     def __init__(
         self,
@@ -82,7 +83,10 @@ class Llama:
         for i in range(max_tokens):
             tokens_seen = prompt_tokens + completion_tokens
-            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
+            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [
+                self.tokens[j]
+                for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)
+            ]

             token = llama_cpp.llama_sample_top_p_top_k(
                 self.ctx,
@@ -128,9 +132,8 @@ class Llama:
                 self.ctx,
             )[:logprobs]

         return {
-            "id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
+            "id": f"cmpl-{str(uuid.uuid4())}",  # Likely to change
             "object": "text_completion",
             "created": int(time.time()),
             "model": self.model_path,
@@ -151,5 +154,3 @@ class Llama:
     def __del__(self):
         llama_cpp.llama_free(self.ctx)
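The last_n_tokens expression reformatted above builds a fixed-width window for the repeat penalty: left-padded with zeros until enough tokens have been generated, then the most recent last_n token ids. A standalone sketch of the same construction (the concrete numbers are made up for illustration):

last_n = 8                     # window size (self.last_n in the class above)
tokens = [101, 102, 103, 104]  # token ids seen so far (self.tokens)
tokens_seen = len(tokens)

last_n_tokens = [0] * max(0, last_n - tokens_seen) + [
    tokens[j] for j in range(max(tokens_seen - last_n, 0), tokens_seen)
]
print(last_n_tokens)  # [0, 0, 0, 0, 101, 102, 103, 104]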

View file

@@ -1,6 +1,15 @@
 import ctypes
-from ctypes import c_int, c_float, c_double, c_char_p, c_void_p, c_bool, POINTER, Structure
+from ctypes import (
+    c_int,
+    c_float,
+    c_double,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    Structure,
+)

 import pathlib
@@ -13,26 +22,32 @@ lib = ctypes.CDLL(str(libfile))
 llama_token = c_int
 llama_token_p = POINTER(llama_token)

 class llama_token_data(Structure):
     _fields_ = [
-        ('id', llama_token), # token id
-        ('p', c_float), # probability of the token
-        ('plog', c_float), # log probability of the token
+        ("id", llama_token),  # token id
+        ("p", c_float),  # probability of the token
+        ("plog", c_float),  # log probability of the token
     ]

 llama_token_data_p = POINTER(llama_token_data)

 class llama_context_params(Structure):
     _fields_ = [
-        ('n_ctx', c_int), # text context
-        ('n_parts', c_int), # -1 for default
-        ('seed', c_int), # RNG seed, 0 for random
-        ('f16_kv', c_bool), # use fp16 for KV cache
-        ('logits_all', c_bool), # the llama_eval() call computes all logits, not just the last one
-        ('vocab_only', c_bool), # only load the vocabulary, no weights
+        ("n_ctx", c_int),  # text context
+        ("n_parts", c_int),  # -1 for default
+        ("seed", c_int),  # RNG seed, 0 for random
+        ("f16_kv", c_bool),  # use fp16 for KV cache
+        (
+            "logits_all",
+            c_bool,
+        ),  # the llama_eval() call computes all logits, not just the last one
+        ("vocab_only", c_bool),  # only load the vocabulary, no weights
     ]

 llama_context_params_p = POINTER(llama_context_params)
 llama_context_p = c_void_p
@@ -74,7 +89,15 @@ lib.llama_token_bos.restype = llama_token
 lib.llama_token_eos.argtypes = []
 lib.llama_token_eos.restype = llama_token

-lib.llama_sample_top_p_top_k.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_double, c_double, c_double]
+lib.llama_sample_top_p_top_k.argtypes = [
+    llama_context_p,
+    llama_token_p,
+    c_int,
+    c_int,
+    c_double,
+    c_double,
+    c_double,
+]
 lib.llama_sample_top_p_top_k.restype = llama_token

 lib.llama_print_timings.argtypes = [llama_context_p]
@@ -86,45 +109,71 @@ lib.llama_reset_timings.restype = None
 lib.llama_print_system_info.argtypes = []
 lib.llama_print_system_info.restype = c_char_p

 # Python functions
 def llama_context_default_params() -> llama_context_params:
     params = lib.llama_context_default_params()
     return params

-def llama_init_from_file(path_model: bytes, params: llama_context_params) -> llama_context_p:
+def llama_init_from_file(
+    path_model: bytes, params: llama_context_params
+) -> llama_context_p:
     """Various functions for loading a ggml llama model.
     Allocate (almost) all memory needed for the model.
-    Return NULL on failure """
+    Return NULL on failure"""
     return lib.llama_init_from_file(path_model, params)

 def llama_free(ctx: llama_context_p):
     """Free all allocated memory"""
     lib.llama_free(ctx)

-def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int) -> c_int:
+def llama_model_quantize(
+    fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
+) -> c_int:
     """Returns 0 on success"""
     return lib.llama_model_quantize(fname_inp, fname_out, itype, qk)

-def llama_eval(ctx: llama_context_p, tokens: llama_token_p, n_tokens: c_int, n_past: c_int, n_threads: c_int) -> c_int:
+def llama_eval(
+    ctx: llama_context_p,
+    tokens: llama_token_p,
+    n_tokens: c_int,
+    n_past: c_int,
+    n_threads: c_int,
+) -> c_int:
     """Run the llama inference to obtain the logits and probabilities for the next token.
     tokens + n_tokens is the provided batch of new tokens to process
     n_past is the number of tokens to use from previous eval calls
     Returns 0 on success"""
     return lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)

-def llama_tokenize(ctx: llama_context_p, text: bytes, tokens: llama_token_p, n_max_tokens: c_int, add_bos: c_bool) -> c_int:
+def llama_tokenize(
+    ctx: llama_context_p,
+    text: bytes,
+    tokens: llama_token_p,
+    n_max_tokens: c_int,
+    add_bos: c_bool,
+) -> c_int:
     """Convert the provided text into tokens.
     The tokens pointer must be large enough to hold the resulting tokens.
     Returns the number of tokens on success, no more than n_max_tokens
-    Returns a negative number on failure - the number of tokens that would have been returned"""
+    Returns a negative number on failure - the number of tokens that would have been returned
+    """
     return lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)

 def llama_n_vocab(ctx: llama_context_p) -> c_int:
     return lib.llama_n_vocab(ctx)

 def llama_n_ctx(ctx: llama_context_p) -> c_int:
     return lib.llama_n_ctx(ctx)

 def llama_get_logits(ctx: llama_context_p):
     """Token logits obtained from the last call to llama_eval()
     The logits for the last token are stored in the last row
@@ -133,25 +182,42 @@ def llama_get_logits(ctx: llama_context_p):
     Cols: n_vocab"""
     return lib.llama_get_logits(ctx)

 def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes:
     """Token Id -> String. Uses the vocabulary in the provided context"""
     return lib.llama_token_to_str(ctx, token)

 def llama_token_bos() -> llama_token:
     return lib.llama_token_bos()

 def llama_token_eos() -> llama_token:
     return lib.llama_token_eos()

-def llama_sample_top_p_top_k(ctx: llama_context_p, last_n_tokens_data: llama_token_p, last_n_tokens_size: c_int, top_k: c_int, top_p: c_double, temp: c_double, repeat_penalty: c_double) -> llama_token:
-    return lib.llama_sample_top_p_top_k(ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty)
+def llama_sample_top_p_top_k(
+    ctx: llama_context_p,
+    last_n_tokens_data: llama_token_p,
+    last_n_tokens_size: c_int,
+    top_k: c_int,
+    top_p: c_double,
+    temp: c_double,
+    repeat_penalty: c_double,
+) -> llama_token:
+    return lib.llama_sample_top_p_top_k(
+        ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty
+    )

 def llama_print_timings(ctx: llama_context_p):
     lib.llama_print_timings(ctx)

 def llama_reset_timings(ctx: llama_context_p):
     lib.llama_reset_timings(ctx)

 def llama_print_system_info() -> bytes:
     """Print system informaiton"""
     return lib.llama_print_system_info()
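As a usage note, llama_tokenize above expects a caller-allocated token buffer, which is why llama_token is exposed as a ctypes type. A minimal sketch of driving these bindings directly; the import path and the model file are placeholders assumed for illustration, not part of the commit:

from llama_cpp import llama_cpp  # assumed import path for the bindings module

params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"./models/ggml-model.bin", params)  # placeholder path

text = b"Hello, llama"
n_max = 64
tokens = (llama_cpp.llama_token * n_max)()  # caller-allocated buffer of llama_token
n = llama_cpp.llama_tokenize(ctx, text, tokens, n_max, True)

print([tokens[i] for i in range(n)])
llama_cpp.llama_free(ctx)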