Extract generate method
parent 1c823f6d0f
commit 30fc0f3866
1 changed file with 22 additions and 12 deletions
@@ -128,6 +128,20 @@ class Llama:
             repeat_penalty=repeat_penalty,
         )
 
+    def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
+        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
+        last_n_tokens.extend(past_tokens)
+        for i in range(max_tokens):
+            token = self._sample(
+                last_n_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                temp=temp,
+                repeat_penalty=repeat_penalty
+            )
+            yield token
+            self._eval([token], len(past_tokens) + i)
+
     def __call__(
         self,
         prompt: str,
@@ -162,8 +176,9 @@ class Llama:
         Returns:
             Response object containing the generated text.
         """
+        completion_id = f"cmpl-{str(uuid.uuid4())}"
+        created = int(time.time())
         text = b""
-        finish_reason = "length"
         completion_tokens = []
         last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
 
@@ -182,14 +197,8 @@ class Llama:
         if stop is not None:
             stop = [s.encode("utf-8") for s in stop]
 
-        for i in range(max_tokens):
-            token = self._sample(
-                last_n_tokens,
-                top_p=top_p,
-                top_k=top_k,
-                temp=temperature,
-                repeat_penalty=repeat_penalty
-            )
+        finish_reason = None
+        for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
@@ -204,7 +213,8 @@ class Llama:
                 finish_reason = "stop"
                 break
 
-            self._eval([token], len(prompt_tokens) + len(completion_tokens))
+        if finish_reason is None:
+            finish_reason = "length"
 
         text = text.decode("utf-8")
 
@@ -220,9 +230,9 @@ class Llama:
             )[:logprobs]
 
         return {
-            "id": f"cmpl-{str(uuid.uuid4())}",  # Likely to change
+            "id": completion_id,
             "object": "text_completion",
-            "created": int(time.time()),
+            "created": created,
             "model": self.model_path,
             "choices": [
                 {
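As context for the change above, here is a minimal sketch of how the extracted _generate generator could be consumed for token-by-token streaming outside of __call__. The model path and the _tokenize/detokenize helpers are assumptions for illustration and are not part of this commit; as in __call__, the prompt is assumed to have been evaluated before _generate is called.

from llama_cpp import Llama  # assumes the package exposes the Llama class shown above
import llama_cpp

llm = Llama(model_path="./models/ggml-model.bin")  # hypothetical model path

prompt_tokens = llm._tokenize(b" Q: Name the planets in the solar system. A: ")  # assumed helper
llm._eval(prompt_tokens, 0)  # _generate expects the prompt to be evaluated already

# _generate yields each sampled token, then evaluates it before sampling the next one.
for token in llm._generate(
    prompt_tokens,
    max_tokens=32,
    top_p=0.95,
    top_k=40,
    temp=0.8,
    repeat_penalty=1.1,
):
    if token == llama_cpp.llama_token_eos():
        break
    print(llm.detokenize([token]).decode("utf-8"), end="", flush=True)  # assumed helper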