llama server wrapper
This commit is contained in:
parent
8fa91332fa
commit
0758cb2d4b
7 changed files with 83 additions and 138 deletions
34
server/README.md
Normal file
34
server/README.md
Normal file
|
@ -0,0 +1,34 @@
|
|||
# Server
|
||||
|
||||
🙊
|
||||
|
||||
## Installation
|
||||
|
||||
If using Apple silicon, you need a Python version that supports arm64:
|
||||
|
||||
```bash
|
||||
wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
|
||||
bash Miniforge3-MacOSX-arm64.sh
|
||||
```
|
||||
|
||||
Get the dependencies:
|
||||
|
||||
```bash
|
||||
pip install llama-cpp-python
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
Put your model in `models/` and run:
|
||||
|
||||
```bash
|
||||
python server.py
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
### `POST /generate`
|
||||
|
||||
model: `string` - The name of the model to use in the `models` folder.
|
||||
prompt: `string` - The prompt to use.
|
|
@ -1,2 +0,0 @@
|
|||
LIBRARY_PATH=$PWD/go-llama.cpp C_INCLUDE_PATH=$PWD/go-llama.cpp go build .
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
module github.com/keypairdev/keypair
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1
|
||||
github.com/sashabaranov/go-openai v1.11.3
|
||||
)
|
|
@ -1,15 +0,0 @@
|
|||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1 h1:UQ8y3kHxBgh3BnaW06y/X97fEN48yHPwWobMz8/aztU=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
|
||||
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
|
||||
github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
|
||||
github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
113
server/main.go
113
server/main.go
|
@ -1,113 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
|
||||
type Model interface {
|
||||
Name() string
|
||||
Handler(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
type LLama7B struct {
|
||||
llama *llama.LLama
|
||||
}
|
||||
|
||||
func NewLLama7B() *LLama7B {
|
||||
llama, err := llama.New("./models/7B/ggml-model-q4_0.bin", llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(128))
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return &LLama7B{
|
||||
llama: llama,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LLama7B) Name() string {
|
||||
return "LLaMA 7B"
|
||||
}
|
||||
|
||||
func (m *LLama7B) Handler(w http.ResponseWriter, r *http.Request) {
|
||||
var text bytes.Buffer
|
||||
io.Copy(&text, r.Body)
|
||||
|
||||
_, err := m.llama.Predict(text.String(), llama.Debug, llama.SetTokenCallback(func(token string) bool {
|
||||
w.Write([]byte(token))
|
||||
return true
|
||||
}), llama.SetTokens(512), llama.SetThreads(runtime.NumCPU()), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("Predict failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
embeds, err := m.llama.Embeddings(text.String())
|
||||
if err != nil {
|
||||
fmt.Printf("Embeddings: error %s \n", err.Error())
|
||||
}
|
||||
fmt.Printf("Embeddings: %v", embeds)
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
type GPT4 struct {
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (g *GPT4) Name() string {
|
||||
return "OpenAI GPT-4"
|
||||
}
|
||||
|
||||
func (g *GPT4) Handler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
client := openai.NewClient("your token")
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT3Dot5Turbo,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Printf("chat completion error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(resp.Choices[0].Message.Content)
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// TODO: add subcommands to spawn different models
|
||||
func main() {
|
||||
model := &LLama7B{}
|
||||
|
||||
http.HandleFunc("/generate", model.Handler)
|
||||
|
||||
fmt.Println("Starting server on :8080")
|
||||
if err := http.ListenAndServe(":8080", nil); err != nil {
|
||||
fmt.Printf("Error starting server: %s\n", err)
|
||||
return
|
||||
}
|
||||
}
|
2
server/requirements.txt
Normal file
2
server/requirements.txt
Normal file
|
@ -0,0 +1,2 @@
|
|||
Flask==2.3.2
|
||||
flask_cors==3.0.10
|
47
server/server.py
Normal file
47
server/server.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
import json
|
||||
import os
|
||||
from llama_cpp import Llama
|
||||
from flask import Flask, Response, stream_with_context, request
|
||||
from flask_cors import CORS, cross_origin
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app) # enable CORS for all routes
|
||||
|
||||
# llms tracks which models are loaded
|
||||
llms = {}
|
||||
|
||||
|
||||
@app.route("/generate", methods=["POST"])
|
||||
def generate():
|
||||
data = request.get_json()
|
||||
model = data.get("model")
|
||||
prompt = data.get("prompt")
|
||||
|
||||
if not model:
|
||||
return Response("Model is required", status=400)
|
||||
if not prompt:
|
||||
return Response("Prompt is required", status=400)
|
||||
if not os.path.exists(f"../models/{model}.bin"):
|
||||
return {"error": "The model file does not exist."}, 400
|
||||
|
||||
if model not in llms:
|
||||
llms[model] = Llama(model_path=f"../models/{model}.bin")
|
||||
|
||||
def stream_response():
|
||||
stream = llms[model](
|
||||
str(prompt), # TODO: optimize prompt based on model
|
||||
max_tokens=4096,
|
||||
stop=["Q:", "\n"],
|
||||
echo=True,
|
||||
stream=True,
|
||||
)
|
||||
for output in stream:
|
||||
yield json.dumps(output)
|
||||
|
||||
return Response(
|
||||
stream_with_context(stream_response()), mimetype="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True, threaded=True, port=5000)
|
Loading…
Reference in a new issue