From 70b8a1ef75ca5cd5f5a525525c89a8d29b2e2a28 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Tue, 28 Mar 2023 04:59:54 -0400
Subject: [PATCH] Add support to get embeddings from high-level api. Closes #4

---
 examples/high_level_api_embedding.py | 12 ++++++++++++
 llama_cpp/llama.py                   | 14 ++++++++++++++
 2 files changed, 26 insertions(+)
 create mode 100644 examples/high_level_api_embedding.py

diff --git a/examples/high_level_api_embedding.py b/examples/high_level_api_embedding.py
new file mode 100644
index 0000000..9b10f7f
--- /dev/null
+++ b/examples/high_level_api_embedding.py
@@ -0,0 +1,12 @@
+import json
+import argparse
+
+from llama_cpp import Llama
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-m", "--model", type=str, default=".//models/...")
+args = parser.parse_args()
+
+llm = Llama(model_path=args.model, embedding=True)
+
+print(llm.embed("Hello world!"))
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 268c27d..cc941e9 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -105,6 +105,20 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def embed(self, text: str):
+        """Embed a string.
+
+        Args:
+            text: The utf-8 encoded string to embed.
+
+        Returns:
+            A list of embeddings.
+        """
+        tokens = self.tokenize(text.encode("utf-8"))
+        self._eval(tokens, 0)
+        embeddings = llama_cpp.llama_get_embeddings(self.ctx)
+        return embeddings[:llama_cpp.llama_n_embd(self.ctx)]
+
     def _eval(self, tokens: List[int], n_past):
         rc = llama_cpp.llama_eval(
             self.ctx,
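
Note (not part of the patch): a minimal sketch of how the new embed()
method could be used to compare two sentences. Only Llama, embedding=True,
and embed() come from the patch itself; the model path is a placeholder
and numpy is an assumed, illustration-only dependency.

import numpy as np

from llama_cpp import Llama

# embedding=True is required so the context computes and exposes
# the embedding vector (see the Llama constructor in llama_cpp/llama.py).
llm = Llama(model_path="./models/ggml-model.bin", embedding=True)  # placeholder path

# embed() tokenizes the text, evaluates it, and returns the context's
# embedding as a list of n_embd floats.
a = np.array(llm.embed("Hello world!"))
b = np.array(llm.embed("Goodbye world!"))

# Cosine similarity between the two sentence embeddings.
print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))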