Extract generate method
This commit is contained in:
parent
1c823f6d0f
commit
30fc0f3866
1 changed file with 22 additions and 12 deletions
|
@ -128,6 +128,20 @@ class Llama:
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
    """Lazily sample up to ``max_tokens`` tokens that follow ``past_tokens``.

    This is a generator: each sampled token is yielded first and is only
    passed to ``self._eval`` when the caller requests the next token. A
    caller that stops iterating (e.g. ``break`` on a stop condition) therefore
    never evaluates the token it stopped on.

    Args:
        past_tokens: Sized iterable of token ids preceding the generation.
            It seeds the repeat-penalty window and its length offsets the
            position handed to ``self._eval``.
        max_tokens: Upper bound on the number of tokens yielded.
        top_p: Forwarded unchanged to ``self._sample``.
        top_k: Forwarded unchanged to ``self._sample``.
        temp: Forwarded unchanged to ``self._sample``.
        repeat_penalty: Forwarded unchanged to ``self._sample``.

    Yields:
        The token returned by ``self._sample`` at each step.
    """
    # Fixed-size sliding window: zero-padded, then filled with the most
    # recent past tokens (older entries fall off the left via ``maxlen``).
    last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
    last_n_tokens.extend(past_tokens)
    n_past = len(past_tokens)

    for i in range(max_tokens):
        token = self._sample(
            last_n_tokens,
            top_p=top_p,
            top_k=top_k,
            temp=temp,
            repeat_penalty=repeat_penalty,
        )
        # Keep the window current. It is local to this generator, so without
        # this append every step would sample against the prompt-only window
        # and freshly generated tokens would never be penalised.
        # NOTE(review): assumes _sample does not append to the window itself
        # -- confirm against its implementation.
        last_n_tokens.append(token)
        yield token
        # Runs only when the consumer asks for another token.
        self._eval([token], n_past + i)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
@ -162,8 +176,9 @@ class Llama:
|
||||||
Returns:
|
Returns:
|
||||||
Response object containing the generated text.
|
Response object containing the generated text.
|
||||||
"""
|
"""
|
||||||
|
completion_id = f"cmpl-{str(uuid.uuid4())}"
|
||||||
|
created= int(time.time())
|
||||||
text = b""
|
text = b""
|
||||||
finish_reason = "length"
|
|
||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
|
last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
|
||||||
|
|
||||||
|
@ -182,14 +197,8 @@ class Llama:
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
stop = [s.encode("utf-8") for s in stop]
|
stop = [s.encode("utf-8") for s in stop]
|
||||||
|
|
||||||
for i in range(max_tokens):
|
finish_reason = None
|
||||||
token = self._sample(
|
for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
|
||||||
last_n_tokens,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
temp=temperature,
|
|
||||||
repeat_penalty=repeat_penalty
|
|
||||||
)
|
|
||||||
if token == llama_cpp.llama_token_eos():
|
if token == llama_cpp.llama_token_eos():
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
break
|
break
|
||||||
|
@ -204,7 +213,8 @@ class Llama:
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
break
|
break
|
||||||
|
|
||||||
self._eval([token], len(prompt_tokens) + len(completion_tokens))
|
if finish_reason is None:
|
||||||
|
finish_reason = "length"
|
||||||
|
|
||||||
text = text.decode("utf-8")
|
text = text.decode("utf-8")
|
||||||
|
|
||||||
|
@ -220,9 +230,9 @@ class Llama:
|
||||||
)[:logprobs]
|
)[:logprobs]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
|
"id": completion_id,
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"created": int(time.time()),
|
"created": created,
|
||||||
"model": self.model_path,
|
"model": self.model_path,
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Reference in a new issue