Fix stop sequence performance bug.
This commit is contained in:
parent
00ea3af51b
commit
8f35bddd7e
2 changed files with 11 additions and 5 deletions
|
@ -10,3 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
- Added first version of the changelog
|
- Added first version of the changelog
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Performance bug in stop sequence check slowing down streaming.
|
|
@ -775,20 +775,22 @@ class Llama:
|
||||||
break
|
break
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
remaining_tokens = completion_tokens[returned_tokens:]
|
||||||
|
remaining_text = self.detokenize(remaining_tokens)
|
||||||
|
remaining_length = len(remaining_text)
|
||||||
|
|
||||||
# We want to avoid yielding any characters from
|
# We want to avoid yielding any characters from
|
||||||
# the generated text if they are part of a stop
|
# the generated text if they are part of a stop
|
||||||
# sequence.
|
# sequence.
|
||||||
first_stop_position = 0
|
first_stop_position = 0
|
||||||
for s in stop_sequences:
|
for s in stop_sequences:
|
||||||
for i in range(len(s), 0, -1):
|
for i in range(min(len(s), remaining_length), 0, -1):
|
||||||
if all_text.endswith(s[:i]):
|
if remaining_text.endswith(s[:i]):
|
||||||
if i > first_stop_position:
|
if i > first_stop_position:
|
||||||
first_stop_position = i
|
first_stop_position = i
|
||||||
break
|
break
|
||||||
|
|
||||||
token_end_position = 0
|
token_end_position = 0
|
||||||
remaining_tokens = completion_tokens[returned_tokens:]
|
|
||||||
remaining_length = len(self.detokenize(remaining_tokens))
|
|
||||||
for token in remaining_tokens:
|
for token in remaining_tokens:
|
||||||
token_end_position += len(self.detokenize([token]))
|
token_end_position += len(self.detokenize([token]))
|
||||||
# Check if stop sequence is in the token
|
# Check if stop sequence is in the token
|
||||||
|
|
Loading…
Reference in a new issue