Fix stop sequence performance bug.

Andrei Betlen 2023-05-26 20:23:49 -04:00
parent 00ea3af51b
commit 8f35bddd7e
2 changed files with 11 additions and 5 deletions


@@ -9,4 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Added first version of the changelog
+### Fixed
+- Performance bug in stop sequence check slowing down streaming.


@@ -775,20 +775,22 @@ class Llama:
                 break
             if stream:
+                remaining_tokens = completion_tokens[returned_tokens:]
+                remaining_text = self.detokenize(remaining_tokens)
+                remaining_length = len(remaining_text)
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
                 first_stop_position = 0
                 for s in stop_sequences:
-                    for i in range(len(s), 0, -1):
-                        if all_text.endswith(s[:i]):
+                    for i in range(min(len(s), remaining_length), 0, -1):
+                        if remaining_text.endswith(s[:i]):
                             if i > first_stop_position:
                                 first_stop_position = i
                             break
                 token_end_position = 0
-                remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token