Fix ctypes typing issue for Arrays

Andrei Betlen 2023-03-31 03:20:15 -04:00
parent 1545b22727
commit 670d390001


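The root problem: real annotations are evaluated when the def statement runs, and on older Python versions ctypes.Array has no __class_getitem__, so a parameter annotated as ctypes.Array[llama_token] raises TypeError as soon as the module is imported. Moving the hint into a PEP 484 type comment keeps it visible to type checkers while the interpreter never evaluates it. A minimal sketch of the failure and of the workaround (names here are illustrative; the exact TypeError wording varies by Python version):

    import ctypes
    from ctypes import c_int

    # On Pythons where ctypes.Array is not subscriptable, the annotation
    # below is evaluated while `def` executes and raises TypeError.
    try:
        def broken(tokens: ctypes.Array[c_int]) -> None:
            pass
    except TypeError as e:
        print(e)  # e.g. "... object is not subscriptable"

    # The workaround used in this commit: a type comment is never
    # evaluated at runtime, but type checkers such as mypy still read it.
    def fixed(
        tokens,  # type: ctypes.Array[c_int]
    ):
        # type: (...) -> None
        pass
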
@@ -1,14 +1,6 @@
 import ctypes
-from ctypes import (
-    c_int,
-    c_float,
-    c_char_p,
-    c_void_p,
-    c_bool,
-    POINTER,
-    Structure,
-)
+from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array
 import pathlib
 from itertools import chain
@@ -116,7 +108,7 @@ _lib.llama_model_quantize.restype = c_int
 # Returns 0 on success
 def llama_eval(
     ctx: llama_context_p,
-    tokens: ctypes.Array[llama_token],
+    tokens,  # type: Array[llama_token]
     n_tokens: c_int,
     n_past: c_int,
     n_threads: c_int,
@@ -136,7 +128,7 @@ _lib.llama_eval.restype = c_int
 def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
-    tokens: ctypes.Array[llama_token],
+    tokens,  # type: Array[llama_token]
     n_max_tokens: c_int,
     add_bos: c_bool,
 ) -> c_int:
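
For callers nothing changes: a parameter hinted as "# type: Array[llama_token]" still expects a ctypes array, created by multiplying the element type by a length. A hypothetical call-site sketch (assuming an initialized llama_context_p in ctx, and that the _lib argtypes set in this module convert plain Python ints as usual):

    prompt = b"Hello, world"
    max_tokens = 32
    tokens = (llama_token * max_tokens)()  # zero-initialized Array[llama_token]
    n = llama_tokenize(ctx, prompt, tokens, max_tokens, c_bool(True))
    if n >= 0:
        llama_eval(ctx, tokens, n, 0, 4)  # n_past=0, n_threads=4
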
@@ -176,7 +168,7 @@ _lib.llama_n_embd.restype = c_int
 # Can be mutated in order to change the probabilities of the next token
 # Rows: n_tokens
 # Cols: n_vocab
-def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
+def llama_get_logits(ctx: llama_context_p):
     return _lib.llama_get_logits(ctx)
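
Dropping the return annotation changes nothing at runtime: _lib.llama_get_logits.restype stays POINTER(c_float) (next hunk), so the result still supports flat indexing over n_tokens * n_vocab floats. A hypothetical sketch of reading the final row, assuming a llama_n_vocab binding from this module and a context that has just evaluated n_tokens tokens:

    n_vocab = llama_n_vocab(ctx)    # assumed binding from this module
    logits = llama_get_logits(ctx)  # a ctypes pointer, not a Python sequence
    # Row-major layout per the comments above: rows = n_tokens, cols = n_vocab
    last_row = [logits[(n_tokens - 1) * n_vocab + j] for j in range(n_vocab)]
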
@@ -186,7 +178,7 @@ _lib.llama_get_logits.restype = POINTER(c_float)
 # Get the embeddings for the input
 # shape: [n_embd] (1-dimensional)
-def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]:
+def llama_get_embeddings(ctx: llama_context_p):
     return _lib.llama_get_embeddings(ctx)
@@ -224,7 +216,7 @@ _lib.llama_token_eos.restype = llama_token
 # TODO: improve the last_n_tokens interface ?
 def llama_sample_top_p_top_k(
     ctx: llama_context_p,
-    last_n_tokens_data: ctypes.Array[llama_token],
+    last_n_tokens_data,  # type: Array[llama_token]
     last_n_tokens_size: c_int,
     top_k: c_int,
     top_p: c_float,