parent
c57317cbf0
commit
e8b954c646
3 changed files with 53 additions and 3 deletions
|
@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
layers = append(layers, baseLayer.Layer)
|
layers = append(layers, baseLayer.Layer)
|
||||||
}
|
}
|
||||||
case "license", "template", "system":
|
case "license", "template", "system":
|
||||||
|
if c.Name == "template" {
|
||||||
|
if _, err := template.Parse(c.Args); err != nil {
|
||||||
|
return fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.Name != "license" {
|
if c.Name != "license" {
|
||||||
// replace
|
// replace
|
||||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||||
|
|
|
@ -56,6 +56,7 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var errRequired = errors.New("is required")
|
var errRequired = errors.New("is required")
|
||||||
|
var errBadTemplate = errors.New("template error")
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
|
@ -609,8 +610,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||||
|
|
||||||
quantization := cmp.Or(r.Quantize, r.Quantization)
|
quantization := cmp.Or(r.Quantize, r.Quantization)
|
||||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
||||||
|
if errors.Is(err, errBadTemplate) {
|
||||||
|
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||||
|
}
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if r.Stream != nil && !*r.Stream {
|
if r.Stream != nil && !*r.Stream {
|
||||||
|
@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case gin.H:
|
case gin.H:
|
||||||
|
status, ok := r["status"].(int)
|
||||||
|
if !ok {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
if errorMsg, ok := r["error"].(string); ok {
|
if errorMsg, ok := r["error"].(string); ok {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
c.JSON(status, gin.H{"error": errorMsg})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
|
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||||
if string(system) != "Say bye!" {
|
if string(system) != "Say bye!" {
|
||||||
t.Errorf("expected \"Say bye!\", actual %s", system)
|
t.Errorf("expected \"Say bye!\", actual %s", system)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("incomplete template", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("template with unclosed if", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("template with undefined function", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
|
|
Loading…
Reference in a new issue