This commit is contained in:
commit
9fa4a19138
10 changed files with 414 additions and 38 deletions
|
@ -346,6 +346,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
||||
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
|
||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
||||
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
|
||||
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
|
||||
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
|
||||
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
|
||||
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
|
||||
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
|
||||
|
|
|
@ -146,6 +146,7 @@ type ToolCall struct {
|
|||
}
|
||||
|
||||
type ToolCallFunction struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
|
|
@ -105,7 +105,7 @@ make apply-patches
|
|||
|
||||
**Pin to new base commit**
|
||||
|
||||
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
|
||||
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
|
||||
|
||||
#### Applying patches
|
||||
|
||||
|
|
|
@ -833,10 +833,21 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
type multiLPath []string
|
||||
|
||||
func (m *multiLPath) Set(value string) error {
|
||||
*m = append(*m, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
func (s *Server) loadModel(
|
||||
params llama.ModelParams,
|
||||
mpath string,
|
||||
lpath string,
|
||||
lpath multiLPath,
|
||||
ppath string,
|
||||
kvSize int,
|
||||
flashAttention bool,
|
||||
|
@ -857,12 +868,14 @@ func (s *Server) loadModel(
|
|||
panic(err)
|
||||
}
|
||||
|
||||
if lpath != "" {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
|
||||
if lpath.String() != "" {
|
||||
for _, path := range lpath {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ppath != "" {
|
||||
var err error
|
||||
|
@ -890,7 +903,6 @@ func main() {
|
|||
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
|
||||
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
lpath := flag.String("lora", "", "Path to lora layer file")
|
||||
port := flag.Int("port", 8080, "Port to expose the server on")
|
||||
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
|
||||
|
@ -900,6 +912,9 @@ func main() {
|
|||
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||
requirements := flag.Bool("requirements", false, "print json requirement information")
|
||||
|
||||
var lpaths multiLPath
|
||||
flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
|
||||
|
||||
flag.Parse()
|
||||
if *requirements {
|
||||
printRequirements(os.Stdout)
|
||||
|
@ -946,7 +961,7 @@ func main() {
|
|||
params := llama.ModelParams{
|
||||
NumGpuLayers: *nGpuLayers,
|
||||
MainGpu: *mainGpu,
|
||||
UseMmap: !*noMmap && *lpath == "",
|
||||
UseMmap: !*noMmap && lpaths.String() == "",
|
||||
UseMlock: *mlock,
|
||||
TensorSplit: tensorSplitFloats,
|
||||
Progress: func(progress float32) {
|
||||
|
@ -955,7 +970,7 @@ func main() {
|
|||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
|
||||
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
|
|
|
@ -144,10 +144,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
|||
// Loop through potential servers
|
||||
finalErr := errors.New("no suitable llama servers found")
|
||||
|
||||
if len(adapters) > 1 {
|
||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||
}
|
||||
|
||||
rDir, err := runners.Refresh(build.EmbedFS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -201,8 +197,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
|||
}
|
||||
|
||||
if len(adapters) > 0 {
|
||||
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
|
||||
params = append(params, "--lora", adapters[0])
|
||||
for _, adapter := range adapters {
|
||||
params = append(params, "--lora", adapter)
|
||||
}
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
|
|
|
@ -140,6 +140,7 @@ type CompletionChunk struct {
|
|||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
|
@ -200,12 +201,13 @@ func toolCallId() string {
|
|||
return "call_" + strings.ToLower(string(b))
|
||||
}
|
||||
|
||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
||||
for i, tc := range r.Message.ToolCalls {
|
||||
func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||
toolCalls := make([]ToolCall, len(tc))
|
||||
for i, tc := range tc {
|
||||
toolCalls[i].ID = toolCallId()
|
||||
toolCalls[i].Type = "function"
|
||||
toolCalls[i].Function.Name = tc.Function.Name
|
||||
toolCalls[i].Index = tc.Function.Index
|
||||
|
||||
args, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
|
@ -215,7 +217,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||
|
||||
toolCalls[i].Function.Arguments = string(args)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletion{
|
||||
Id: id,
|
||||
Object: "chat.completion",
|
||||
|
@ -244,6 +250,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||
}
|
||||
|
||||
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletionChunk{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
|
@ -252,7 +259,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||
SystemFingerprint: "fp_ollama",
|
||||
Choices: []ChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: Message{Role: "assistant", Content: r.Message.Content},
|
||||
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(reason) > 0 {
|
||||
return &reason
|
||||
|
|
|
@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) {
|
|||
Stream: &False,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "chat handler with streaming tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||
],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris?",
|
||||
},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
|
|
|
@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
|
|||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
|
||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
|
||||
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
|
|
|
@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
var sb strings.Builder
|
||||
var toolCallIndex int = 0
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
|
@ -1492,7 +1495,37 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
|
||||
// however this was a simple change for now without reworking streaming logic of this (and other)
|
||||
// handlers
|
||||
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
|
||||
ch <- res
|
||||
return
|
||||
}
|
||||
|
||||
// Streaming tool calls:
|
||||
// If tools are recognized, use a flag to track the sending of a tool downstream
|
||||
// This ensures that content is cleared from the message on the last chunk sent
|
||||
sb.WriteString(r.Content)
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
toolCalls[i].Function.Index = toolCallIndex
|
||||
toolCallIndex++
|
||||
}
|
||||
res.Message.Content = ""
|
||||
sb.Reset()
|
||||
ch <- res
|
||||
return
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
// Send any remaining content if no tool calls were detected
|
||||
if toolCallIndex == 0 {
|
||||
res.Message.Content = sb.String()
|
||||
}
|
||||
ch <- res
|
||||
}
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -25,10 +26,14 @@ type mockRunner struct {
|
|||
// CompletionRequest is only valid until the next call to Completion
|
||||
llm.CompletionRequest
|
||||
llm.CompletionResponse
|
||||
CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
|
||||
}
|
||||
|
||||
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
m.CompletionRequest = r
|
||||
if m.CompletionFn != nil {
|
||||
return m.CompletionFn(ctx, r, fn)
|
||||
}
|
||||
fn(m.CompletionResponse)
|
||||
return nil
|
||||
}
|
||||
|
@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
|
|||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s
|
||||
TEMPLATE """
|
||||
{{- if .System }}System: {{ .System }} {{ end }}
|
||||
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||
{{- if .Tools }}
|
||||
{{ .Tools }}
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- .Role }}: {{ .Content }}
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
{{ end }}"""
|
||||
`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
|
@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||
})
|
||||
|
||||
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create test-system model: %d", w.Code)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mock.CompletionResponse = llm.CompletionResponse{
|
||||
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
|
||||
Done: true,
|
||||
DoneReason: "done",
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
}
|
||||
|
||||
streamRequest := true
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Seattle?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
var errResp struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
|
||||
t.Logf("Failed to decode error response: %v", err)
|
||||
} else {
|
||||
t.Logf("Error response: %s", errResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.ToolCalls == nil {
|
||||
t.Error("expected tool calls, got nil")
|
||||
}
|
||||
|
||||
expectedToolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Seattle, WA",
|
||||
"unit": "celsius",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
|
||||
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("messages with tools (streaming)", func(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate streaming response with multiple chunks
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
// Send chunks with small delays to simulate streaming
|
||||
responses := []llm.CompletionResponse{
|
||||
{
|
||||
Content: `{"name":"get_`,
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
},
|
||||
{
|
||||
Content: `weather","arguments":{"location":"Seattle`,
|
||||
Done: false,
|
||||
PromptEvalCount: 2,
|
||||
PromptEvalDuration: 1,
|
||||
},
|
||||
{
|
||||
Content: `, WA","unit":"celsius"}}`,
|
||||
Done: true,
|
||||
DoneReason: "tool_call",
|
||||
PromptEvalCount: 3,
|
||||
PromptEvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
fn(resp)
|
||||
time.Sleep(10 * time.Millisecond) // Small delay between chunks
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Seattle?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Read and validate the streamed responses
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var finalToolCall api.ToolCall
|
||||
|
||||
for {
|
||||
var resp api.ChatResponse
|
||||
if err := decoder.Decode(&resp); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Done {
|
||||
if len(resp.Message.ToolCalls) != 1 {
|
||||
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
||||
}
|
||||
finalToolCall = resp.Message.ToolCalls[0]
|
||||
}
|
||||
}
|
||||
|
||||
expectedToolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Seattle, WA",
|
||||
"unit": "celsius",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue