# llama.cpp/examples/high_level_api/langchain_custom_llm.py
import argparse
from llama_cpp import Llama
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any


class LlamaLLM(LLM):
    """LangChain ``LLM`` wrapper around a local ``llama_cpp.Llama`` model.

    The model is loaded once in ``__init__``; every LangChain call is then
    forwarded to it as a plain text completion.
    """

    model_path: str  # path the wrapped model was loaded from
    llm: Llama  # the loaded llama-cpp-python model

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM type."""
        return "llama-cpp-python"

    def __init__(self, model_path: str, **kwargs: Any):
        """Load the model at *model_path* and initialise the wrapper's fields.

        Args:
            model_path: Path of the model file, passed to ``Llama``.
            **kwargs: Extra field values forwarded unchanged to ``LLM.__init__``.
        """
        llm = Llama(model_path=model_path)
        # The declared fields are handed to the base constructor rather than
        # assigned on ``self`` directly (presumably because ``LLM`` validates
        # its fields at construction time — confirm against LangChain).
        super().__init__(model_path=model_path, llm=llm, **kwargs)

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Complete *prompt* with the wrapped model and return the text.

        Args:
            prompt: Text to complete.
            stop: Optional stop sequences; ``None`` is sent as an empty list.

        Returns:
            The generated text of the first choice in the model's response.
        """
        response = self.llm(prompt, stop=stop or [])
        return response["choices"][0]["text"]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM instance."""
        return {"model_path": self.model_path}


parser = argparse.ArgumentParser()
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="../models/7B/ggml-models.bin",
    help="path to the model file to load",
)
args = parser.parse_args()

# Load the model
llm = LlamaLLM(model_path=args.model)

# Basic Q&A: stop at the next "Question:" or newline so only one answer is
# produced.
answer = llm(
    "Question: What is the capital of France? Answer: ", stop=["Question:", "\n"]
)
print(f"Answer: {answer.strip()}")

# Using in a chain (these imports are only needed for this part of the demo)
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

prompt = PromptTemplate(
    input_variables=["product"],
    template="\n\n### Instruction:\nWrite a good name for a company that makes {product}\n\n### Response:\n",
)
chain = LLMChain(llm=llm, prompt=prompt)

# Run the chain only specifying the input variable.
print(chain.run("colorful socks"))