use ffi for tokenizing/detokenizing
This commit is contained in:
parent
646371f56d
commit
26a00a0410
3 changed files with 57 additions and 144 deletions
43
llm/ext_server/server.cpp
vendored
43
llm/ext_server/server.cpp
vendored
|
@ -2714,21 +2714,6 @@ static json format_partial_response(
|
|||
return res;
|
||||
}
|
||||
|
||||
static json format_tokenizer_response(const std::vector<llama_token> &tokens)
|
||||
{
|
||||
return json {
|
||||
{"tokens", tokens}
|
||||
};
|
||||
}
|
||||
|
||||
static json format_detokenized_response(std::string content)
|
||||
{
|
||||
return json {
|
||||
{"content", content}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
static void log_server_request(const httplib::Request &req, const httplib::Response &res)
|
||||
{
|
||||
// skip GH copilot requests when using default port
|
||||
|
@ -3218,34 +3203,6 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
});
|
||||
|
||||
svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
std::vector<llama_token> tokens;
|
||||
if (body.count("content") != 0)
|
||||
{
|
||||
tokens = llama.tokenize(body["content"], false);
|
||||
}
|
||||
const json data = format_tokenizer_response(tokens);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
std::string content;
|
||||
if (body.count("tokens") != 0)
|
||||
{
|
||||
const std::vector<llama_token> tokens = body["tokens"];
|
||||
content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
|
||||
}
|
||||
|
||||
const json data = format_detokenized_response(content);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
|
45
llm/llm.go
45
llm/llm.go
|
@ -12,6 +12,7 @@ package llm
|
|||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
|
@ -37,3 +38,47 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
type llamaModel struct {
|
||||
m *C.struct_llama_model
|
||||
}
|
||||
|
||||
func newLlamaModel(p string) *llamaModel {
|
||||
cs := C.CString(p)
|
||||
defer C.free(unsafe.Pointer(cs))
|
||||
|
||||
return &llamaModel{
|
||||
C.llama_load_model_from_file(
|
||||
cs,
|
||||
C.llama_model_default_params(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *llamaModel) Close() {
|
||||
C.llama_free_model(llm.m)
|
||||
}
|
||||
|
||||
func (llm *llamaModel) Tokenize(s string) []int {
|
||||
cs := C.CString(s)
|
||||
defer C.free(unsafe.Pointer(cs))
|
||||
|
||||
tokens := make([]int, len(s)+2)
|
||||
if n := C.llama_tokenize(llm.m, cs, C.int(len(s)), (*C.llama_token)(unsafe.Pointer(&tokens[0])), C.int(len(s)+2), false, true); n > 0 {
|
||||
return tokens[:n]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *llamaModel) Detokenize(i32s []int) string {
|
||||
var sb strings.Builder
|
||||
for _, i32 := range i32s {
|
||||
c := make([]byte, 512)
|
||||
if n := C.llama_token_to_piece(llm.m, C.llama_token(i32), (*C.char)(unsafe.Pointer(&c[0])), C.int(len(c)), false); n > 0 {
|
||||
sb.WriteString(unsafe.String(&c[0], n))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
|
113
llm/server.go
113
llm/server.go
|
@ -57,6 +57,8 @@ type llmServer struct {
|
|||
loadDuration time.Duration // Record how long it took the model to load
|
||||
loadProgress float32
|
||||
|
||||
*llamaModel
|
||||
|
||||
sem *semaphore.Weighted
|
||||
}
|
||||
|
||||
|
@ -306,6 +308,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||
totalLayers: ggml.KV().BlockCount() + 1,
|
||||
gpuCount: gpuCount,
|
||||
done: make(chan error, 1),
|
||||
llamaModel: newLlamaModel(model),
|
||||
}
|
||||
|
||||
s.cmd.Env = os.Environ()
|
||||
|
@ -843,12 +846,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||
if err != nil {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(EmbeddingRequest{Content: prompt}); err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), &b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||
}
|
||||
|
@ -878,108 +881,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||
return embedding.Embedding, nil
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do encode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read encode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err := json.Unmarshal(body, &encoded); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return encoded.Tokens, nil
|
||||
}
|
||||
|
||||
type DetokenizeRequest struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type DetokenizeResponse struct {
|
||||
Content string `json:"content"`
|
||||
return s.llamaModel.Tokenize(content), nil
|
||||
}
|
||||
|
||||
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("do decode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read decode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm decode error: %s", body)
|
||||
return "", fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return decoded.Content, nil
|
||||
return s.llamaModel.Detokenize(tokens), nil
|
||||
}
|
||||
|
||||
func (s *llmServer) Close() error {
|
||||
|
@ -997,6 +904,10 @@ func (s *llmServer) Close() error {
|
|||
slog.Debug("llama server stopped")
|
||||
}
|
||||
|
||||
if s.llamaModel != nil {
|
||||
s.llamaModel.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue