diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5bcfad8..87f69ce 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -262,6 +262,17 @@ class Llama: }, } + def embed(self, input: str) -> List[float]: + """Embed a string. + + Args: + input: The utf-8 encoded string to embed. + + Returns: + A list of embeddings + """ + return list(map(float, self.create_embedding(input)["data"][0]["embedding"])) + def _create_completion( self, prompt: str, @@ -341,7 +352,7 @@ class Llama: "model": self.model_path, "choices": [ { - "text": text[start :].decode("utf-8"), + "text": text[start:].decode("utf-8"), "index": 0, "logprobs": None, "finish_reason": None,