fix: Incorporate embedding pooling layer fixes (#1194)
* remove division by token count
* truncate to n_batch, not n_ctx
parent ae71ad1a14
commit 7bb91f025f
1 changed file with 17 additions and 16 deletions
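For reference, the first bullet corresponds to the normalization change in the second hunk below. A minimal sketch of the new behaviour, using only numpy (the helper name is illustrative, not code from this repository): the pooled embedding is now returned unchanged unless normalization is requested, whereas the old code divided by the sequence's token count whenever normalize was False.

import numpy as np
from typing import List

def postprocess(embedding: List[float], normalize: bool) -> List[float]:
    # Behaviour after this commit: no division by token count;
    # the vector is only rescaled when normalize=True (L2 normalization).
    if normalize:
        norm = float(np.linalg.norm(embedding))
        embedding = [v / norm for v in embedding]
    return embedding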
@@ -762,7 +762,7 @@ class Llama:
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch

         if self.context_params.embedding == False:
             raise RuntimeError(
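The second bullet appears both in the hunk above and throughout the hunk below: with sequence-level pooling each input is decoded inside a single llama_batch, so the per-input token limit is the batch size (self.n_batch), not the context window. A rough sketch of the truncation and overrun check, with illustrative values that are not taken from this diff:

n_batch = 512                       # assumed batch size
tokens = list(range(600))           # stand-in for a tokenized input
truncate = True

if truncate:
    tokens = tokens[:n_batch]       # clamp each input to the batch size

if len(tokens) > n_batch:           # only reachable when truncate=False
    raise ValueError(
        f"Requested tokens ({len(tokens)}) exceed batch size of {n_batch}"
    )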
@@ -782,54 +782,55 @@ class Llama:

         # decode and fetch embeddings
         data: List[List[float]] = []
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()

             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)

         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0

         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]

             n_tokens = len(tokens)
             total_tokens += n_tokens

             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )

             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0

             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1

         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)

         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
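As a usage-level illustration (a sketch under assumptions: the model path, constructor arguments, and the embed signature with normalize/truncate keywords are inferred from the hunks above rather than shown by this commit):

from llama_cpp import Llama

# embedding=True is required; otherwise Llama.embed raises the RuntimeError
# seen in the first hunk. n_batch bounds the token count of each input.
llm = Llama(model_path="model.gguf", embedding=True, n_batch=512)

# With normalize=False the pooled vectors are returned without rescaling;
# with truncate=True (the default) each input is cut to n_batch tokens.
vectors = llm.embed(["first document", "second document"], normalize=False)
print(len(vectors), len(vectors[0]))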