Load Embedding Model on Empty Input (#6325)
* load on empty input * no load on invalid input
This commit is contained in:
parent
01b80e9ffc
commit
8b00a415ab
2 changed files with 9 additions and 77 deletions
|
@ -324,13 +324,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
input = append(input, v.(string))
|
input = append(input, v.(string))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
if req.Input != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(input) == 0 {
|
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||||
|
@ -341,6 +338,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
|
if len(input) == 0 {
|
||||||
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
kvData, err := getKVData(m.ModelPath, false)
|
kvData, err := getKVData(m.ModelPath, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
|
|
@ -272,76 +272,6 @@ func Test_Routes(t *testing.T) {
|
||||||
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "Embed Handler Empty Input",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
embedReq := api.EmbedRequest{
|
|
||||||
Model: "t-bone",
|
|
||||||
Input: "",
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(embedReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
if contentType != "application/json; charset=utf-8" {
|
|
||||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
|
||||||
}
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var embedResp api.EmbedResponse
|
|
||||||
err = json.Unmarshal(body, &embedResp)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if embedResp.Model != "t-bone" {
|
|
||||||
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
if embedResp.Embeddings == nil {
|
|
||||||
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(embedResp.Embeddings) != 0 {
|
|
||||||
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "Embed Handler Invalid Input",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
embedReq := api.EmbedRequest{
|
|
||||||
Model: "t-bone",
|
|
||||||
Input: 2,
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(embedReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
if contentType != "application/json; charset=utf-8" {
|
|
||||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
|
||||||
}
|
|
||||||
_, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
Loading…
Reference in a new issue