fix: Incorporate embedding pooling layer fixes (#1194)

* remove division by token count

* truncate to n_batch, not n_ctx
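
In code terms, the two bullets above reduce to the rules sketched below. This is a minimal sketch, not the library's code; clamp_tokens and postprocess are hypothetical helper names used only to illustrate the new behavior.

import numpy as np
from typing import List

def clamp_tokens(tokens: List[int], n_batch: int) -> List[int]:
    # hypothetical helper: each input is now clipped to the batch size rather
    # than the context window, since every sequence must fit in a single batch
    return tokens[:n_batch]

def postprocess(embedding: List[float], normalize: bool) -> List[float]:
    # hypothetical helper: the un-normalized path now returns the raw pooled
    # embedding (no division by the sequence's token count); normalize=True
    # applies plain L2 normalization instead
    if normalize:
        norm = float(np.linalg.norm(embedding))
        embedding = [v / norm for v in embedding]
    return embedding
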
Author: Douglas Hanley
Date: 2024-02-15 14:16:30 -06:00 (committed by GitHub)
Parent: ae71ad1a14
Commit: 7bb91f025f


@@ -762,7 +762,7 @@ class Llama:
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch
 
         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -782,54 +782,55 @@ class Llama:
         # decode and fetch embeddings
         data: List[List[float]] = []
 
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)
 
         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0
 
         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]
 
             n_tokens = len(tokens)
             total_tokens += n_tokens
 
             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )
 
             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1
 
         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
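
A brief usage sketch of the patched method, under the assumption that normalize and truncate are keyword arguments of Llama.embed() (as the variables in the diff suggest); the model path, n_batch value, and input strings are placeholders.

from llama_cpp import Llama

# embedding=True is required by the RuntimeError check above; n_batch now also
# bounds the token length of each input sequence
llm = Llama(model_path="model.gguf", embedding=True, n_batch=512)

# with truncate=True, inputs longer than n_batch tokens are clipped instead of
# raising ValueError; with normalize=True each vector is L2-normalized
vectors = llm.embed(["first document", "second document"], normalize=True, truncate=True)
print(len(vectors), len(vectors[0]))  # expected: number of inputs, embedding size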