fix: Incorporate embedding pooling layer fixes (#1194)

* remove division by token count

* truncate to n_batch, not n_ctx
commit 7bb91f025f (parent ae71ad1a14)
Author: Douglas Hanley (committed by GitHub)
Date:   2024-02-15 14:16:30 -06:00

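For reference, a minimal sketch of how the embedding path touched by this patch is exercised from user code; the constructor and embed() arguments shown here are assumptions about the llama-cpp-python API rather than part of this change:

    from llama_cpp import Llama

    # embedding=True enables the embedding path patched below; with this fix,
    # n_batch (not n_ctx) bounds each input when truncation is enabled.
    llm = Llama(model_path="model.gguf", embedding=True, n_batch=512)

    # Inputs longer than n_batch tokens are truncated, or rejected with a
    # ValueError when truncation is disabled.
    vectors = llm.embed(["first document", "second document"])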
@@ -762,7 +762,7 @@ class Llama:
         """
         assert self._ctx.ctx is not None
         n_embd = self.n_embd()
-        n_ctx = self.n_ctx()
+        n_batch = self.n_batch
 
         if self.context_params.embedding == False:
             raise RuntimeError(
@@ -782,54 +782,55 @@ class Llama:
         # decode and fetch embeddings
         data: List[List[float]] = []
 
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)
 
         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0
 
         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]
 
             n_tokens = len(tokens)
             total_tokens += n_tokens
 
             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )
 
             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1
 
         # hanlde last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
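Net effect on the returned vectors: the old code rescaled by the L2 norm when normalize was set and otherwise divided by the sequence length s; after this change the raw pooled embedding is kept unless normalize=True, in which case it is L2-normalized. A standalone sketch of the new behavior (the helper name is illustrative, not from the library):

    import numpy as np

    def maybe_normalize(embedding, normalize):
        # Mirrors the updated decode_batch logic: rescale only when requested,
        # and always by the L2 norm, never by the token count.
        if normalize:
            norm = float(np.linalg.norm(embedding))
            return [v / norm for v in embedding]
        return embedding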