diff --git a/server/routes.go b/server/routes.go index 8184db75..a745fb20 100644 --- a/server/routes.go +++ b/server/routes.go @@ -188,21 +188,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { } var b bytes.Buffer - if err := tmpl.Execute(&b, values); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if req.Context != nil { s, err := r.Detokenize(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - prompt = s + b.String() - } else { - prompt = b.String(); + b.WriteString(s) } + + if err := tmpl.Execute(&b, values); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + prompt = b.String() } slog.Debug("generate request", "prompt", prompt, "images", images) @@ -241,12 +241,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) if !req.Raw { - tokens, err := r.Tokenize(c.Request.Context(), prompt + sb.String()) + tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) if err != nil { ch <- gin.H{"error": err.Error()} return } - res.Context = tokens[:] + res.Context = tokens } }