Merge pull request #3464 from dhiltgen/subprocess

Fix numgpu opt miscomparison
This commit is contained in:
Daniel Hiltgen 2024-04-02 20:10:17 -07:00 committed by GitHub
commit 464d817824
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 7 deletions

View file

@ -33,14 +33,14 @@ type LlamaServer struct {
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options *api.Options
options api.Options
}
var cpuOnlyFamilies = []string{
"mamba",
}
func NewLlamaServer(model string, adapters, projectors []string, opts *api.Options) (*LlamaServer, error) {
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}

View file

@ -69,7 +69,7 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
@ -107,7 +107,7 @@ func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.
loaded.adapters = model.AdapterPaths
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = opts
loaded.Options = &opts
}
if loaded.expireTimer == nil {
@ -220,7 +220,7 @@ func GenerateHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, &opts, sessionDuration); err != nil {
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@ -465,7 +465,7 @@ func EmbeddingsHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, &opts, sessionDuration); err != nil {
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@ -1272,7 +1272,7 @@ func ChatHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, &opts, sessionDuration); err != nil {
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}