fix: Incorporate embedding pooling layer fixes (#1194)
* remove division by token count * truncate to n_batch, not n_ctx
This commit is contained in:
parent
ae71ad1a14
commit
7bb91f025f
1 changed files with 17 additions and 16 deletions
|
@ -762,7 +762,7 @@ class Llama:
|
|||
"""
|
||||
assert self._ctx.ctx is not None
|
||||
n_embd = self.n_embd()
|
||||
n_ctx = self.n_ctx()
|
||||
n_batch = self.n_batch
|
||||
|
||||
if self.context_params.embedding == False:
|
||||
raise RuntimeError(
|
||||
|
@ -782,54 +782,55 @@ class Llama:
|
|||
|
||||
# decode and fetch embeddings
|
||||
data: List[List[float]] = []
|
||||
def decode_batch(sizes: List[int]):
|
||||
def decode_batch(n_seq: int):
|
||||
assert self._ctx.ctx is not None
|
||||
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
||||
self._ctx.decode(self._batch)
|
||||
self._batch.reset()
|
||||
|
||||
# store embeddings
|
||||
for i, s in enumerate(sizes):
|
||||
embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
|
||||
for i in range(n_seq):
|
||||
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
|
||||
:n_embd
|
||||
]
|
||||
norm = np.linalg.norm(embedding) if normalize else s
|
||||
embedding: List[float] = [v / float(norm) for v in embedding]
|
||||
if normalize:
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
embedding = [v / norm for v in embedding]
|
||||
data.append(embedding)
|
||||
|
||||
# init state
|
||||
total_tokens = 0
|
||||
t_batch = 0
|
||||
s_sizes: List[int] = []
|
||||
p_batch = 0
|
||||
|
||||
# accumulate batches and encode
|
||||
for text in inputs:
|
||||
tokens = self.tokenize(text.encode("utf-8"))
|
||||
if truncate:
|
||||
tokens = tokens[:n_ctx]
|
||||
tokens = tokens[:n_batch]
|
||||
|
||||
n_tokens = len(tokens)
|
||||
total_tokens += n_tokens
|
||||
|
||||
# check for overrun
|
||||
if n_tokens > n_ctx:
|
||||
if n_tokens > n_batch:
|
||||
raise ValueError(
|
||||
f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
|
||||
f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
|
||||
)
|
||||
|
||||
# time to eval batch
|
||||
if t_batch + n_tokens > self._n_ctx:
|
||||
decode_batch(s_sizes)
|
||||
if t_batch + n_tokens > n_batch:
|
||||
decode_batch(p_batch)
|
||||
t_batch = 0
|
||||
s_sizes = []
|
||||
p_batch = 0
|
||||
|
||||
# add to batch
|
||||
self._batch.add_sequence(tokens, len(s_sizes), False)
|
||||
self._batch.add_sequence(tokens, p_batch, False)
|
||||
t_batch += n_tokens
|
||||
s_sizes.append(n_tokens)
|
||||
p_batch += 1
|
||||
|
||||
# hanlde last batch
|
||||
decode_batch(s_sizes)
|
||||
decode_batch(p_batch)
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self._ctx.ctx)
|
||||
|
|
Loading…
Reference in a new issue