diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 3e09a20..f3c7b4f 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -762,7 +762,7 @@ class Llama:
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch
 
         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -782,54 +782,55 @@ class Llama:
         # decode and fetch embeddings
         data: List[List[float]] = []
 
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
-                    :n_embd
-                ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)
 
         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0
 
         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]
 
             n_tokens = len(tokens)
             total_tokens += n_tokens
 
             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )
 
             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1
 
         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
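
Usage note (not part of the patch): after this change, per-input truncation and the overrun check are bounded by n_batch rather than n_ctx, and decode_batch takes a sequence count instead of a list of per-sequence sizes. Below is a minimal sketch of how the patched embed path would be exercised; the model path and n_batch value are placeholders, not taken from the patch.

    from llama_cpp import Llama

    # hypothetical local GGUF model path; embedding=True is required by embed()
    llm = Llama(model_path="./model.gguf", embedding=True, n_batch=512)

    # each input is tokenized, truncated to n_batch tokens when truncate=True,
    # packed into batches of at most n_batch tokens, and decoded; one
    # L2-normalized vector is returned per input when normalize=True
    vectors = llm.embed(
        ["first document", "second document"], normalize=True, truncate=True
    )
    print(len(vectors), len(vectors[0]))  # number of inputs, embedding dim (n_embd)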