fix: k8s dependency version: emicklei/go-restful

`emicklei/go-restful` is used by:
- `k8s.io/client-go`  (Godeps)

Refs:
- e121606b0d/Godeps/Godeps.json
This commit is contained in:
Fernandez Ludovic 2017-06-11 18:03:28 +02:00 committed by Ludovic Fernandez
parent a7297b49a4
commit 5aa017d9b5
26 changed files with 957 additions and 380 deletions

2
glide.lock generated
View file

@ -194,7 +194,7 @@ imports:
- name: github.com/elazarl/go-bindata-assetfs - name: github.com/elazarl/go-bindata-assetfs
version: 30f82fa23fd844bd5bb1e5f216db87fd77b5eb43 version: 30f82fa23fd844bd5bb1e5f216db87fd77b5eb43
- name: github.com/emicklei/go-restful - name: github.com/emicklei/go-restful
version: 892402ba11a2e2fd5e1295dd633481f27365f14d version: 89ef8af493ab468a45a42bb0d89a06fccdd2fb22
subpackages: subpackages:
- log - log
- swagger - swagger

View file

@ -5,10 +5,12 @@ package restful
// that can be found in the LICENSE file. // that can be found in the LICENSE file.
import ( import (
"bufio"
"compress/gzip" "compress/gzip"
"compress/zlib" "compress/zlib"
"errors" "errors"
"io" "io"
"net"
"net/http" "net/http"
"strings" "strings"
) )
@ -20,6 +22,7 @@ var EnableContentEncoding = false
type CompressingResponseWriter struct { type CompressingResponseWriter struct {
writer http.ResponseWriter writer http.ResponseWriter
compressor io.WriteCloser compressor io.WriteCloser
encoding string
} }
// Header is part of http.ResponseWriter interface // Header is part of http.ResponseWriter interface
@ -35,6 +38,9 @@ func (c *CompressingResponseWriter) WriteHeader(status int) {
// Write is part of http.ResponseWriter interface // Write is part of http.ResponseWriter interface
// It is passed through the compressor // It is passed through the compressor
func (c *CompressingResponseWriter) Write(bytes []byte) (int, error) { func (c *CompressingResponseWriter) Write(bytes []byte) (int, error) {
if c.isCompressorClosed() {
return -1, errors.New("Compressing error: tried to write data using closed compressor")
}
return c.compressor.Write(bytes) return c.compressor.Write(bytes)
} }
@ -44,8 +50,36 @@ func (c *CompressingResponseWriter) CloseNotify() <-chan bool {
} }
// Close the underlying compressor // Close the underlying compressor
func (c *CompressingResponseWriter) Close() { func (c *CompressingResponseWriter) Close() error {
if c.isCompressorClosed() {
return errors.New("Compressing error: tried to close already closed compressor")
}
c.compressor.Close() c.compressor.Close()
if ENCODING_GZIP == c.encoding {
currentCompressorProvider.ReleaseGzipWriter(c.compressor.(*gzip.Writer))
}
if ENCODING_DEFLATE == c.encoding {
currentCompressorProvider.ReleaseZlibWriter(c.compressor.(*zlib.Writer))
}
// gc hint needed?
c.compressor = nil
return nil
}
func (c *CompressingResponseWriter) isCompressorClosed() bool {
return nil == c.compressor
}
// Hijack implements the Hijacker interface
// This is especially useful when combining Container.EnabledContentEncoding
// in combination with websockets (for instance gorilla/websocket)
func (c *CompressingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := c.writer.(http.Hijacker)
if !ok {
return nil, nil, errors.New("ResponseWriter doesn't support Hijacker interface")
}
return hijacker.Hijack()
} }
// WantsCompressedResponse reads the Accept-Encoding header to see if and which encoding is requested. // WantsCompressedResponse reads the Accept-Encoding header to see if and which encoding is requested.
@ -73,13 +107,15 @@ func NewCompressingResponseWriter(httpWriter http.ResponseWriter, encoding strin
c.writer = httpWriter c.writer = httpWriter
var err error var err error
if ENCODING_GZIP == encoding { if ENCODING_GZIP == encoding {
w := GzipWriterPool.Get().(*gzip.Writer) w := currentCompressorProvider.AcquireGzipWriter()
w.Reset(httpWriter) w.Reset(httpWriter)
c.compressor = w c.compressor = w
c.encoding = ENCODING_GZIP
} else if ENCODING_DEFLATE == encoding { } else if ENCODING_DEFLATE == encoding {
w := ZlibWriterPool.Get().(*zlib.Writer) w := currentCompressorProvider.AcquireZlibWriter()
w.Reset(httpWriter) w.Reset(httpWriter)
c.compressor = w c.compressor = w
c.encoding = ENCODING_DEFLATE
} else { } else {
return nil, errors.New("Unknown encoding:" + encoding) return nil, errors.New("Unknown encoding:" + encoding)
} }

View file

@ -0,0 +1,103 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"compress/gzip"
"compress/zlib"
)
// BoundedCachedCompressors is a CompressorProvider that uses a cache with a fixed amount
// of writers and readers (resources).
// If a new resource is acquired and all are in use, it will return a new unmanaged resource.
type BoundedCachedCompressors struct {
gzipWriters chan *gzip.Writer
gzipReaders chan *gzip.Reader
zlibWriters chan *zlib.Writer
writersCapacity int
readersCapacity int
}
// NewBoundedCachedCompressors returns a new, with filled cache, BoundedCachedCompressors.
func NewBoundedCachedCompressors(writersCapacity, readersCapacity int) *BoundedCachedCompressors {
b := &BoundedCachedCompressors{
gzipWriters: make(chan *gzip.Writer, writersCapacity),
gzipReaders: make(chan *gzip.Reader, readersCapacity),
zlibWriters: make(chan *zlib.Writer, writersCapacity),
writersCapacity: writersCapacity,
readersCapacity: readersCapacity,
}
for ix := 0; ix < writersCapacity; ix++ {
b.gzipWriters <- newGzipWriter()
b.zlibWriters <- newZlibWriter()
}
for ix := 0; ix < readersCapacity; ix++ {
b.gzipReaders <- newGzipReader()
}
return b
}
// AcquireGzipWriter returns an resettable *gzip.Writer. Needs to be released.
func (b *BoundedCachedCompressors) AcquireGzipWriter() *gzip.Writer {
var writer *gzip.Writer
select {
case writer, _ = <-b.gzipWriters:
default:
// return a new unmanaged one
writer = newGzipWriter()
}
return writer
}
// ReleaseGzipWriter accepts a writer (does not have to be one that was cached)
// only when the cache has room for it. It will ignore it otherwise.
func (b *BoundedCachedCompressors) ReleaseGzipWriter(w *gzip.Writer) {
// forget the unmanaged ones
if len(b.gzipWriters) < b.writersCapacity {
b.gzipWriters <- w
}
}
// AcquireGzipReader returns a *gzip.Reader. Needs to be released.
func (b *BoundedCachedCompressors) AcquireGzipReader() *gzip.Reader {
var reader *gzip.Reader
select {
case reader, _ = <-b.gzipReaders:
default:
// return a new unmanaged one
reader = newGzipReader()
}
return reader
}
// ReleaseGzipReader accepts a reader (does not have to be one that was cached)
// only when the cache has room for it. It will ignore it otherwise.
func (b *BoundedCachedCompressors) ReleaseGzipReader(r *gzip.Reader) {
// forget the unmanaged ones
if len(b.gzipReaders) < b.readersCapacity {
b.gzipReaders <- r
}
}
// AcquireZlibWriter returns an resettable *zlib.Writer. Needs to be released.
func (b *BoundedCachedCompressors) AcquireZlibWriter() *zlib.Writer {
var writer *zlib.Writer
select {
case writer, _ = <-b.zlibWriters:
default:
// return a new unmanaged one
writer = newZlibWriter()
}
return writer
}
// ReleaseZlibWriter accepts a writer (does not have to be one that was cached)
// only when the cache has room for it. It will ignore it otherwise.
func (b *BoundedCachedCompressors) ReleaseZlibWriter(w *zlib.Writer) {
// forget the unmanaged ones
if len(b.zlibWriters) < b.writersCapacity {
b.zlibWriters <- w
}
}

View file

@ -1,5 +1,9 @@
package restful package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
@ -7,12 +11,50 @@ import (
"sync" "sync"
) )
// GzipWriterPool is used to get reusable zippers. // SyncPoolCompessors is a CompressorProvider that use the standard sync.Pool.
// The Get() result must be type asserted to *gzip.Writer. type SyncPoolCompessors struct {
var GzipWriterPool = &sync.Pool{ GzipWriterPool *sync.Pool
New: func() interface{} { GzipReaderPool *sync.Pool
return newGzipWriter() ZlibWriterPool *sync.Pool
}
// NewSyncPoolCompessors returns a new ("empty") SyncPoolCompessors.
func NewSyncPoolCompessors() *SyncPoolCompessors {
return &SyncPoolCompessors{
GzipWriterPool: &sync.Pool{
New: func() interface{} { return newGzipWriter() },
}, },
GzipReaderPool: &sync.Pool{
New: func() interface{} { return newGzipReader() },
},
ZlibWriterPool: &sync.Pool{
New: func() interface{} { return newZlibWriter() },
},
}
}
func (s *SyncPoolCompessors) AcquireGzipWriter() *gzip.Writer {
return s.GzipWriterPool.Get().(*gzip.Writer)
}
func (s *SyncPoolCompessors) ReleaseGzipWriter(w *gzip.Writer) {
s.GzipWriterPool.Put(w)
}
func (s *SyncPoolCompessors) AcquireGzipReader() *gzip.Reader {
return s.GzipReaderPool.Get().(*gzip.Reader)
}
func (s *SyncPoolCompessors) ReleaseGzipReader(r *gzip.Reader) {
s.GzipReaderPool.Put(r)
}
func (s *SyncPoolCompessors) AcquireZlibWriter() *zlib.Writer {
return s.ZlibWriterPool.Get().(*zlib.Writer)
}
func (s *SyncPoolCompessors) ReleaseZlibWriter(w *zlib.Writer) {
s.ZlibWriterPool.Put(w)
} }
func newGzipWriter() *gzip.Writer { func newGzipWriter() *gzip.Writer {
@ -24,17 +66,11 @@ func newGzipWriter() *gzip.Writer {
return writer return writer
} }
// GzipReaderPool is used to get reusable zippers.
// The Get() result must be type asserted to *gzip.Reader.
var GzipReaderPool = &sync.Pool{
New: func() interface{} {
return newGzipReader()
},
}
func newGzipReader() *gzip.Reader { func newGzipReader() *gzip.Reader {
// create with an empty reader (but with GZIP header); it will be replaced before using the gzipReader // create with an empty reader (but with GZIP header); it will be replaced before using the gzipReader
w := GzipWriterPool.Get().(*gzip.Writer) // we can safely use currentCompressProvider because it is set on package initialization.
w := currentCompressorProvider.AcquireGzipWriter()
defer currentCompressorProvider.ReleaseGzipWriter(w)
b := new(bytes.Buffer) b := new(bytes.Buffer)
w.Reset(b) w.Reset(b)
w.Flush() w.Flush()
@ -46,14 +82,6 @@ func newGzipReader() *gzip.Reader {
return reader return reader
} }
// ZlibWriterPool is used to get reusable zippers.
// The Get() result must be type asserted to *zlib.Writer.
var ZlibWriterPool = &sync.Pool{
New: func() interface{} {
return newZlibWriter()
},
}
func newZlibWriter() *zlib.Writer { func newZlibWriter() *zlib.Writer {
writer, err := zlib.NewWriterLevel(new(bytes.Buffer), gzip.BestSpeed) writer, err := zlib.NewWriterLevel(new(bytes.Buffer), gzip.BestSpeed)
if err != nil { if err != nil {

53
vendor/github.com/emicklei/go-restful/compressors.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"compress/gzip"
"compress/zlib"
)
type CompressorProvider interface {
// Returns a *gzip.Writer which needs to be released later.
// Before using it, call Reset().
AcquireGzipWriter() *gzip.Writer
// Releases an aqcuired *gzip.Writer.
ReleaseGzipWriter(w *gzip.Writer)
// Returns a *gzip.Reader which needs to be released later.
AcquireGzipReader() *gzip.Reader
// Releases an aqcuired *gzip.Reader.
ReleaseGzipReader(w *gzip.Reader)
// Returns a *zlib.Writer which needs to be released later.
// Before using it, call Reset().
AcquireZlibWriter() *zlib.Writer
// Releases an aqcuired *zlib.Writer.
ReleaseZlibWriter(w *zlib.Writer)
}
// DefaultCompressorProvider is the actual provider of compressors (zlib or gzip).
var currentCompressorProvider CompressorProvider
func init() {
currentCompressorProvider = NewSyncPoolCompessors()
}
// CurrentCompressorProvider returns the current CompressorProvider.
// It is initialized using a SyncPoolCompessors.
func CurrentCompressorProvider() CompressorProvider {
return currentCompressorProvider
}
// CompressorProvider sets the actual provider of compressors (zlib or gzip).
func SetCompressorProvider(p CompressorProvider) {
if p == nil {
panic("cannot set compressor provider to nil")
}
currentCompressorProvider = p
}

View file

@ -6,6 +6,7 @@ package restful
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@ -83,18 +84,42 @@ func (c *Container) EnableContentEncoding(enabled bool) {
c.contentEncodingEnabled = enabled c.contentEncodingEnabled = enabled
} }
// Add a WebService to the Container. It will detect duplicate root paths and panic in that case. // Add a WebService to the Container. It will detect duplicate root paths and exit in that case.
func (c *Container) Add(service *WebService) *Container { func (c *Container) Add(service *WebService) *Container {
c.webServicesLock.Lock() c.webServicesLock.Lock()
defer c.webServicesLock.Unlock() defer c.webServicesLock.Unlock()
// If registered on root then no additional specific mapping is needed
// if rootPath was not set then lazy initialize it
if len(service.rootPath) == 0 {
service.Path("/")
}
// cannot have duplicate root paths
for _, each := range c.webServices {
if each.RootPath() == service.RootPath() {
log.Printf("[restful] WebService with duplicate root path detected:['%v']", each)
os.Exit(1)
}
}
// If not registered on root then add specific mapping
if !c.isRegisteredOnRoot { if !c.isRegisteredOnRoot {
pattern := c.fixedPrefixPath(service.RootPath()) c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
}
c.webServices = append(c.webServices, service)
return c
}
// addHandler may set a new HandleFunc for the serveMux
// this function must run inside the critical region protected by the webServicesLock.
// returns true if the function was registered on root ("/")
func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
pattern := fixedPrefixPath(service.RootPath())
// check if root path registration is needed // check if root path registration is needed
if "/" == pattern || "" == pattern { if "/" == pattern || "" == pattern {
c.ServeMux.HandleFunc("/", c.dispatch) serveMux.HandleFunc("/", c.dispatch)
c.isRegisteredOnRoot = true return true
} else { }
// detect if registration already exists // detect if registration already exists
alreadyMapped := false alreadyMapped := false
for _, each := range c.webServices { for _, each := range c.webServices {
@ -104,38 +129,36 @@ func (c *Container) Add(service *WebService) *Container {
} }
} }
if !alreadyMapped { if !alreadyMapped {
c.ServeMux.HandleFunc(pattern, c.dispatch) serveMux.HandleFunc(pattern, c.dispatch)
if !strings.HasSuffix(pattern, "/") { if !strings.HasSuffix(pattern, "/") {
c.ServeMux.HandleFunc(pattern+"/", c.dispatch) serveMux.HandleFunc(pattern+"/", c.dispatch)
} }
} }
} return false
}
// cannot have duplicate root paths
for _, each := range c.webServices {
if each.RootPath() == service.RootPath() {
log.Printf("[restful] WebService with duplicate root path detected:['%v']", each)
os.Exit(1)
}
}
// if rootPath was not set then lazy initialize it
if len(service.rootPath) == 0 {
service.Path("/")
}
c.webServices = append(c.webServices, service)
return c
} }
func (c *Container) Remove(ws *WebService) error { func (c *Container) Remove(ws *WebService) error {
if c.ServeMux == http.DefaultServeMux {
errMsg := fmt.Sprintf("[restful] cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
log.Printf(errMsg)
return errors.New(errMsg)
}
c.webServicesLock.Lock() c.webServicesLock.Lock()
defer c.webServicesLock.Unlock() defer c.webServicesLock.Unlock()
// build a new ServeMux and re-register all WebServices
newServeMux := http.NewServeMux()
newServices := []*WebService{} newServices := []*WebService{}
for ix := range c.webServices { newIsRegisteredOnRoot := false
if c.webServices[ix].rootPath != ws.rootPath { for _, each := range c.webServices {
newServices = append(newServices, c.webServices[ix]) if each.rootPath != ws.rootPath {
// If not registered on root then add specific mapping
if !newIsRegisteredOnRoot {
newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
}
newServices = append(newServices, each)
} }
} }
c.webServices = newServices c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
return nil return nil
} }
@ -251,7 +274,7 @@ func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.R
} }
// fixedPrefixPath returns the fixed part of the partspec ; it may include template vars {} // fixedPrefixPath returns the fixed part of the partspec ; it may include template vars {}
func (c Container) fixedPrefixPath(pathspec string) string { func fixedPrefixPath(pathspec string) string {
varBegin := strings.Index(pathspec, "{") varBegin := strings.Index(pathspec, "{")
if -1 == varBegin { if -1 == varBegin {
return pathspec return pathspec
@ -260,19 +283,19 @@ func (c Container) fixedPrefixPath(pathspec string) string {
} }
// ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server // ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
func (c Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) { func (c *Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) {
c.ServeMux.ServeHTTP(httpwriter, httpRequest) c.ServeMux.ServeHTTP(httpwriter, httpRequest)
} }
// Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics. // Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
func (c Container) Handle(pattern string, handler http.Handler) { func (c *Container) Handle(pattern string, handler http.Handler) {
c.ServeMux.Handle(pattern, handler) c.ServeMux.Handle(pattern, handler)
} }
// HandleWithFilter registers the handler for the given pattern. // HandleWithFilter registers the handler for the given pattern.
// Container's filter chain is applied for handler. // Container's filter chain is applied for handler.
// If a handler already exists for pattern, HandleWithFilter panics. // If a handler already exists for pattern, HandleWithFilter panics.
func (c Container) HandleWithFilter(pattern string, handler http.Handler) { func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) { f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
if len(c.containerFilters) == 0 { if len(c.containerFilters) == 0 {
handler.ServeHTTP(httpResponse, httpRequest) handler.ServeHTTP(httpResponse, httpRequest)
@ -295,7 +318,7 @@ func (c *Container) Filter(filter FilterFunction) {
} }
// RegisteredWebServices returns the collections of added WebServices // RegisteredWebServices returns the collections of added WebServices
func (c Container) RegisteredWebServices() []*WebService { func (c *Container) RegisteredWebServices() []*WebService {
c.webServicesLock.RLock() c.webServicesLock.RLock()
defer c.webServicesLock.RUnlock() defer c.webServicesLock.RUnlock()
result := make([]*WebService, len(c.webServices)) result := make([]*WebService, len(c.webServices))
@ -306,7 +329,7 @@ func (c Container) RegisteredWebServices() []*WebService {
} }
// computeAllowedMethods returns a list of HTTP methods that are valid for a Request // computeAllowedMethods returns a list of HTTP methods that are valid for a Request
func (c Container) computeAllowedMethods(req *Request) []string { func (c *Container) computeAllowedMethods(req *Request) []string {
// Go through all RegisteredWebServices() and all its Routes to collect the options // Go through all RegisteredWebServices() and all its Routes to collect the options
methods := []string{} methods := []string{}
requestPath := req.Request.URL.Path requestPath := req.Request.URL.Path

View file

@ -5,6 +5,7 @@ package restful
// that can be found in the LICENSE file. // that can be found in the LICENSE file.
import ( import (
"regexp"
"strconv" "strconv"
"strings" "strings"
) )
@ -19,11 +20,13 @@ import (
type CrossOriginResourceSharing struct { type CrossOriginResourceSharing struct {
ExposeHeaders []string // list of Header names ExposeHeaders []string // list of Header names
AllowedHeaders []string // list of Header names AllowedHeaders []string // list of Header names
AllowedDomains []string // list of allowed values for Http Origin. If empty all are allowed. AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed.
AllowedMethods []string AllowedMethods []string
MaxAge int // number of seconds before requiring new Options request MaxAge int // number of seconds before requiring new Options request
CookiesAllowed bool CookiesAllowed bool
Container *Container Container *Container
allowedOriginPatterns []*regexp.Regexp // internal field for origin regexp check.
} }
// Filter is a filter function that implements the CORS flow as documented on http://enable-cors.org/server.html // Filter is a filter function that implements the CORS flow as documented on http://enable-cors.org/server.html
@ -37,22 +40,13 @@ func (c CrossOriginResourceSharing) Filter(req *Request, resp *Response, chain *
chain.ProcessFilter(req, resp) chain.ProcessFilter(req, resp)
return return
} }
if len(c.AllowedDomains) > 0 { // if provided then origin must be included if !c.isOriginAllowed(origin) { // check whether this origin is allowed
included := false
for _, each := range c.AllowedDomains {
if each == origin {
included = true
break
}
}
if !included {
if trace { if trace {
traceLogger.Printf("HTTP Origin:%s is not part of %v", origin, c.AllowedDomains) traceLogger.Printf("HTTP Origin:%s is not part of %v, neither matches any part of %v", origin, c.AllowedDomains, c.allowedOriginPatterns)
} }
chain.ProcessFilter(req, resp) chain.ProcessFilter(req, resp)
return return
} }
}
if req.Request.Method != "OPTIONS" { if req.Request.Method != "OPTIONS" {
c.doActualRequest(req, resp) c.doActualRequest(req, resp)
chain.ProcessFilter(req, resp) chain.ProcessFilter(req, resp)
@ -74,8 +68,12 @@ func (c CrossOriginResourceSharing) doActualRequest(req *Request, resp *Response
func (c *CrossOriginResourceSharing) doPreflightRequest(req *Request, resp *Response) { func (c *CrossOriginResourceSharing) doPreflightRequest(req *Request, resp *Response) {
if len(c.AllowedMethods) == 0 { if len(c.AllowedMethods) == 0 {
if c.Container == nil {
c.AllowedMethods = DefaultContainer.computeAllowedMethods(req)
} else {
c.AllowedMethods = c.Container.computeAllowedMethods(req) c.AllowedMethods = c.Container.computeAllowedMethods(req)
} }
}
acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod) acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod)
if !c.isValidAccessControlRequestMethod(acrm, c.AllowedMethods) { if !c.isValidAccessControlRequestMethod(acrm, c.AllowedMethods) {
@ -124,13 +122,32 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
if len(c.AllowedDomains) == 0 { if len(c.AllowedDomains) == 0 {
return true return true
} }
allowed := false allowed := false
for _, each := range c.AllowedDomains { for _, domain := range c.AllowedDomains {
if each == origin { if domain == origin {
allowed = true allowed = true
break break
} }
} }
if !allowed {
if len(c.allowedOriginPatterns) == 0 {
// compile allowed domains to allowed origin patterns
allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
if err != nil {
return false
}
c.allowedOriginPatterns = allowedOriginRegexps
}
for _, pattern := range c.allowedOriginPatterns {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
}
return allowed return allowed
} }
@ -170,3 +187,16 @@ func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header str
} }
return false return false
} }
// Take a list of strings and compile them into a list of regular expressions.
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
regexps := []*regexp.Regexp{}
for _, regexpStr := range regexpStrings {
r, err := regexp.Compile(regexpStr)
if err != nil {
return regexps, err
}
regexps = append(regexps, r)
}
return regexps, nil
}

View file

@ -44,16 +44,16 @@ func (c CurlyRouter) SelectRoute(
} }
// selectRoutes return a collection of Route from a WebService that matches the path tokens from the request. // selectRoutes return a collection of Route from a WebService that matches the path tokens from the request.
func (c CurlyRouter) selectRoutes(ws *WebService, requestTokens []string) []Route { func (c CurlyRouter) selectRoutes(ws *WebService, requestTokens []string) sortableCurlyRoutes {
candidates := &sortableCurlyRoutes{[]*curlyRoute{}} candidates := sortableCurlyRoutes{}
for _, each := range ws.routes { for _, each := range ws.routes {
matches, paramCount, staticCount := c.matchesRouteByPathTokens(each.pathParts, requestTokens) matches, paramCount, staticCount := c.matchesRouteByPathTokens(each.pathParts, requestTokens)
if matches { if matches {
candidates.add(&curlyRoute{each, paramCount, staticCount}) // TODO make sure Routes() return pointers? candidates.add(curlyRoute{each, paramCount, staticCount}) // TODO make sure Routes() return pointers?
} }
} }
sort.Sort(sort.Reverse(candidates)) sort.Sort(sort.Reverse(candidates))
return candidates.routes() return candidates
} }
// matchesRouteByPathTokens computes whether it matches, howmany parameters do match and what the number of static path elements are. // matchesRouteByPathTokens computes whether it matches, howmany parameters do match and what the number of static path elements are.
@ -110,9 +110,9 @@ func (c CurlyRouter) regularMatchesPathToken(routeToken string, colon int, reque
// detectRoute selectes from a list of Route the first match by inspecting both the Accept and Content-Type // detectRoute selectes from a list of Route the first match by inspecting both the Accept and Content-Type
// headers of the Request. See also RouterJSR311 in jsr311.go // headers of the Request. See also RouterJSR311 in jsr311.go
func (c CurlyRouter) detectRoute(candidateRoutes []Route, httpRequest *http.Request) (*Route, error) { func (c CurlyRouter) detectRoute(candidateRoutes sortableCurlyRoutes, httpRequest *http.Request) (*Route, error) {
// tracing is done inside detectRoute // tracing is done inside detectRoute
return RouterJSR311{}.detectRoute(candidateRoutes, httpRequest) return RouterJSR311{}.detectRoute(candidateRoutes.routes(), httpRequest)
} }
// detectWebService returns the best matching webService given the list of path tokens. // detectWebService returns the best matching webService given the list of path tokens.

View file

@ -11,30 +11,28 @@ type curlyRoute struct {
staticCount int staticCount int
} }
type sortableCurlyRoutes struct { type sortableCurlyRoutes []curlyRoute
candidates []*curlyRoute
func (s *sortableCurlyRoutes) add(route curlyRoute) {
*s = append(*s, route)
} }
func (s *sortableCurlyRoutes) add(route *curlyRoute) { func (s sortableCurlyRoutes) routes() (routes []Route) {
s.candidates = append(s.candidates, route) for _, each := range s {
}
func (s *sortableCurlyRoutes) routes() (routes []Route) {
for _, each := range s.candidates {
routes = append(routes, each.route) // TODO change return type routes = append(routes, each.route) // TODO change return type
} }
return routes return routes
} }
func (s *sortableCurlyRoutes) Len() int { func (s sortableCurlyRoutes) Len() int {
return len(s.candidates) return len(s)
} }
func (s *sortableCurlyRoutes) Swap(i, j int) { func (s sortableCurlyRoutes) Swap(i, j int) {
s.candidates[i], s.candidates[j] = s.candidates[j], s.candidates[i] s[i], s[j] = s[j], s[i]
} }
func (s *sortableCurlyRoutes) Less(i, j int) bool { func (s sortableCurlyRoutes) Less(i, j int) bool {
ci := s.candidates[i] ci := s[i]
cj := s.candidates[j] cj := s[j]
// primary key // primary key
if ci.staticCount < cj.staticCount { if ci.staticCount < cj.staticCount {

View file

@ -162,6 +162,11 @@ Default value is false; it will recover from panics. This has performance implic
SetCacheReadEntity controls whether the response data ([]byte) is cached such that ReadEntity is repeatable. SetCacheReadEntity controls whether the response data ([]byte) is cached such that ReadEntity is repeatable.
If you expect to read large amounts of payload data, and you do not use this feature, you should set it to false. If you expect to read large amounts of payload data, and you do not use this feature, you should set it to false.
restful.SetCompressorProvider(NewBoundedCachedCompressors(20, 20))
If content encoding is enabled then the default strategy for getting new gzip/zlib writers and readers is to use a sync.Pool.
Because writers are expensive structures, performance is even more improved when using a preloaded cache. You can also inject your own implementation.
Trouble shooting Trouble shooting
This package has the means to produce detail logging of the complete Http request matching process and filter invocation. This package has the means to produce detail logging of the complete Http request matching process and filter invocation.

View file

@ -0,0 +1,163 @@
package restful
// Copyright 2015 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import (
"encoding/json"
"encoding/xml"
"strings"
"sync"
)
// EntityReaderWriter can read and write values using an encoding such as JSON,XML.
type EntityReaderWriter interface {
// Read a serialized version of the value from the request.
// The Request may have a decompressing reader. Depends on Content-Encoding.
Read(req *Request, v interface{}) error
// Write a serialized version of the value on the response.
// The Response may have a compressing writer. Depends on Accept-Encoding.
// status should be a valid Http Status code
Write(resp *Response, status int, v interface{}) error
}
// entityAccessRegistry is a singleton
var entityAccessRegistry = &entityReaderWriters{
protection: new(sync.RWMutex),
accessors: map[string]EntityReaderWriter{},
}
// entityReaderWriters associates MIME to an EntityReaderWriter
type entityReaderWriters struct {
protection *sync.RWMutex
accessors map[string]EntityReaderWriter
}
func init() {
RegisterEntityAccessor(MIME_JSON, NewEntityAccessorJSON(MIME_JSON))
RegisterEntityAccessor(MIME_XML, NewEntityAccessorXML(MIME_XML))
}
// RegisterEntityAccessor add/overrides the ReaderWriter for encoding content with this MIME type.
func RegisterEntityAccessor(mime string, erw EntityReaderWriter) {
entityAccessRegistry.protection.Lock()
defer entityAccessRegistry.protection.Unlock()
entityAccessRegistry.accessors[mime] = erw
}
// NewEntityAccessorJSON returns a new EntityReaderWriter for accessing JSON content.
// This package is already initialized with such an accessor using the MIME_JSON contentType.
func NewEntityAccessorJSON(contentType string) EntityReaderWriter {
return entityJSONAccess{ContentType: contentType}
}
// NewEntityAccessorXML returns a new EntityReaderWriter for accessing XML content.
// This package is already initialized with such an accessor using the MIME_XML contentType.
func NewEntityAccessorXML(contentType string) EntityReaderWriter {
return entityXMLAccess{ContentType: contentType}
}
// accessorAt returns the registered ReaderWriter for this MIME type.
func (r *entityReaderWriters) accessorAt(mime string) (EntityReaderWriter, bool) {
r.protection.RLock()
defer r.protection.RUnlock()
er, ok := r.accessors[mime]
if !ok {
// retry with reverse lookup
// more expensive but we are in an exceptional situation anyway
for k, v := range r.accessors {
if strings.Contains(mime, k) {
return v, true
}
}
}
return er, ok
}
// entityXMLAccess is a EntityReaderWriter for XML encoding
type entityXMLAccess struct {
// This is used for setting the Content-Type header when writing
ContentType string
}
// Read unmarshalls the value from XML
func (e entityXMLAccess) Read(req *Request, v interface{}) error {
return xml.NewDecoder(req.Request.Body).Decode(v)
}
// Write marshalls the value to JSON and set the Content-Type Header.
func (e entityXMLAccess) Write(resp *Response, status int, v interface{}) error {
return writeXML(resp, status, e.ContentType, v)
}
// writeXML marshalls the value to JSON and set the Content-Type Header.
func writeXML(resp *Response, status int, contentType string, v interface{}) error {
if v == nil {
resp.WriteHeader(status)
// do not write a nil representation
return nil
}
if resp.prettyPrint {
// pretty output must be created and written explicitly
output, err := xml.MarshalIndent(v, " ", " ")
if err != nil {
return err
}
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
_, err = resp.Write([]byte(xml.Header))
if err != nil {
return err
}
_, err = resp.Write(output)
return err
}
// not-so-pretty
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
return xml.NewEncoder(resp).Encode(v)
}
// entityJSONAccess is a EntityReaderWriter for JSON encoding
type entityJSONAccess struct {
// This is used for setting the Content-Type header when writing
ContentType string
}
// Read unmarshalls the value from JSON
func (e entityJSONAccess) Read(req *Request, v interface{}) error {
decoder := json.NewDecoder(req.Request.Body)
decoder.UseNumber()
return decoder.Decode(v)
}
// Write marshalls the value to JSON and set the Content-Type Header.
func (e entityJSONAccess) Write(resp *Response, status int, v interface{}) error {
return writeJSON(resp, status, e.ContentType, v)
}
// write marshalls the value to JSON and set the Content-Type Header.
func writeJSON(resp *Response, status int, contentType string, v interface{}) error {
if v == nil {
resp.WriteHeader(status)
// do not write a nil representation
return nil
}
if resp.prettyPrint {
// pretty output must be created and written explicitly
output, err := json.MarshalIndent(v, " ", " ")
if err != nil {
return err
}
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
_, err = resp.Write(output)
return err
}
// not-so-pretty
resp.Header().Set(HEADER_ContentType, contentType)
resp.WriteHeader(status)
return json.NewEncoder(resp).Encode(v)
}

View file

@ -74,7 +74,7 @@ func (r RouterJSR311) detectRoute(routes []Route, httpRequest *http.Request) (*R
// accept // accept
outputMediaOk := []Route{} outputMediaOk := []Route{}
accept := httpRequest.Header.Get(HEADER_Accept) accept := httpRequest.Header.Get(HEADER_Accept)
if accept == "" { if len(accept) == 0 {
accept = "*/*" accept = "*/*"
} }
for _, each := range inputMediaOk { for _, each := range inputMediaOk {
@ -88,7 +88,8 @@ func (r RouterJSR311) detectRoute(routes []Route, httpRequest *http.Request) (*R
} }
return nil, NewError(http.StatusNotAcceptable, "406: Not Acceptable") return nil, NewError(http.StatusNotAcceptable, "406: Not Acceptable")
} }
return r.bestMatchByMedia(outputMediaOk, contentType, accept), nil // return r.bestMatchByMedia(outputMediaOk, contentType, accept), nil
return &outputMediaOk[0], nil
} }
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2 // http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2

View file

@ -15,7 +15,7 @@ var Logger StdLogger
func init() { func init() {
// default Logger // default Logger
SetLogger(stdlog.New(os.Stdout, "[restful] ", stdlog.LstdFlags|stdlog.Lshortfile)) SetLogger(stdlog.New(os.Stderr, "[restful] ", stdlog.LstdFlags|stdlog.Lshortfile))
} }
func SetLogger(customLogger StdLogger) { func SetLogger(customLogger StdLogger) {

45
vendor/github.com/emicklei/go-restful/mime.go generated vendored Normal file
View file

@ -0,0 +1,45 @@
package restful
import (
"strconv"
"strings"
)
type mime struct {
media string
quality float64
}
// insertMime adds a mime to a list and keeps it sorted by quality.
func insertMime(l []mime, e mime) []mime {
for i, each := range l {
// if current mime has lower quality then insert before
if e.quality > each.quality {
left := append([]mime{}, l[0:i]...)
return append(append(left, e), l[i:]...)
}
}
return append(l, e)
}
// sortedMimes returns a list of mime sorted (desc) by its specified quality.
func sortedMimes(accept string) (sorted []mime) {
for _, each := range strings.Split(accept, ",") {
typeAndQuality := strings.Split(strings.Trim(each, " "), ";")
if len(typeAndQuality) == 1 {
sorted = insertMime(sorted, mime{typeAndQuality[0], 1.0})
} else {
// take factor
parts := strings.Split(typeAndQuality[1], "=")
if len(parts) == 2 {
f, err := strconv.ParseFloat(parts[1], 64)
if err != nil {
traceLogger.Printf("unable to parse quality in %s, %v", each, err)
} else {
sorted = insertMime(sorted, mime{typeAndQuality[0], f})
}
}
}
}
return
}

View file

@ -8,7 +8,8 @@ import "strings"
// OPTIONSFilter is a filter function that inspects the Http Request for the OPTIONS method // OPTIONSFilter is a filter function that inspects the Http Request for the OPTIONS method
// and provides the response with a set of allowed methods for the request URL Path. // and provides the response with a set of allowed methods for the request URL Path.
// As for any filter, you can also install it for a particular WebService within a Container // As for any filter, you can also install it for a particular WebService within a Container.
// Note: this filter is not needed when using CrossOriginResourceSharing (for CORS).
func (c *Container) OPTIONSFilter(req *Request, resp *Response, chain *FilterChain) { func (c *Container) OPTIONSFilter(req *Request, resp *Response, chain *FilterChain) {
if "OPTIONS" != req.Request.Method { if "OPTIONS" != req.Request.Method {
chain.ProcessFilter(req, resp) chain.ProcessFilter(req, resp)
@ -19,6 +20,7 @@ func (c *Container) OPTIONSFilter(req *Request, resp *Response, chain *FilterCha
// OPTIONSFilter is a filter function that inspects the Http Request for the OPTIONS method // OPTIONSFilter is a filter function that inspects the Http Request for the OPTIONS method
// and provides the response with a set of allowed methods for the request URL Path. // and provides the response with a set of allowed methods for the request URL Path.
// Note: this filter is not needed when using CrossOriginResourceSharing (for CORS).
func OPTIONSFilter() FilterFunction { func OPTIONSFilter() FilterFunction {
return DefaultContainer.OPTIONSFilter return DefaultContainer.OPTIONSFilter
} }

View file

@ -30,7 +30,7 @@ type Parameter struct {
// ParameterData represents the state of a Parameter. // ParameterData represents the state of a Parameter.
// It is made public to make it accessible to e.g. the Swagger package. // It is made public to make it accessible to e.g. the Swagger package.
type ParameterData struct { type ParameterData struct {
Name, Description, DataType string Name, Description, DataType, DataFormat string
Kind int Kind int
Required bool Required bool
AllowableValues map[string]string AllowableValues map[string]string
@ -95,6 +95,12 @@ func (p *Parameter) DataType(typeName string) *Parameter {
return p return p
} }
// DataFormat sets the dataFormat field for Swagger UI
func (p *Parameter) DataFormat(formatName string) *Parameter {
p.data.DataFormat = formatName
return p
}
// DefaultValue sets the default value field and returns the receiver // DefaultValue sets the default value field and returns the receiver
func (p *Parameter) DefaultValue(stringRepresentation string) *Parameter { func (p *Parameter) DefaultValue(stringRepresentation string) *Parameter {
p.data.DefaultValue = stringRepresentation p.data.DefaultValue = stringRepresentation

View file

@ -6,14 +6,9 @@ package restful
import ( import (
"bytes" "bytes"
"compress/gzip"
"compress/zlib" "compress/zlib"
"encoding/json"
"encoding/xml"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings"
) )
var defaultRequestContentType string var defaultRequestContentType string
@ -81,62 +76,43 @@ func (r *Request) HeaderParameter(name string) string {
return r.Request.Header.Get(name) return r.Request.Header.Get(name)
} }
// ReadEntity checks the Accept header and reads the content into the entityPointer // ReadEntity checks the Accept header and reads the content into the entityPointer.
// May be called multiple times in the request-response flow
func (r *Request) ReadEntity(entityPointer interface{}) (err error) { func (r *Request) ReadEntity(entityPointer interface{}) (err error) {
defer r.Request.Body.Close()
contentType := r.Request.Header.Get(HEADER_ContentType) contentType := r.Request.Header.Get(HEADER_ContentType)
contentEncoding := r.Request.Header.Get(HEADER_ContentEncoding) contentEncoding := r.Request.Header.Get(HEADER_ContentEncoding)
if doCacheReadEntityBytes {
return r.cachingReadEntity(contentType, contentEncoding, entityPointer)
}
// unmarshall directly from request Body
return r.decodeEntity(r.Request.Body, contentType, contentEncoding, entityPointer)
}
func (r *Request) cachingReadEntity(contentType string, contentEncoding string, entityPointer interface{}) (err error) { // OLD feature, cache the body for reads
var buffer []byte if doCacheReadEntityBytes {
if r.bodyContent != nil { if r.bodyContent == nil {
buffer = *r.bodyContent data, err := ioutil.ReadAll(r.Request.Body)
} else {
buffer, err = ioutil.ReadAll(r.Request.Body)
if err != nil { if err != nil {
return err return err
} }
r.bodyContent = &buffer r.bodyContent = &data
}
r.Request.Body = ioutil.NopCloser(bytes.NewReader(*r.bodyContent))
} }
return r.decodeEntity(bytes.NewReader(buffer), contentType, contentEncoding, entityPointer)
}
func (r *Request) decodeEntity(reader io.Reader, contentType string, contentEncoding string, entityPointer interface{}) (err error) {
entityReader := reader
// check if the request body needs decompression // check if the request body needs decompression
if ENCODING_GZIP == contentEncoding { if ENCODING_GZIP == contentEncoding {
gzipReader := GzipReaderPool.Get().(*gzip.Reader) gzipReader := currentCompressorProvider.AcquireGzipReader()
gzipReader.Reset(reader) defer currentCompressorProvider.ReleaseGzipReader(gzipReader)
entityReader = gzipReader gzipReader.Reset(r.Request.Body)
r.Request.Body = gzipReader
} else if ENCODING_DEFLATE == contentEncoding { } else if ENCODING_DEFLATE == contentEncoding {
zlibReader, err := zlib.NewReader(reader) zlibReader, err := zlib.NewReader(r.Request.Body)
if err != nil { if err != nil {
return err return err
} }
entityReader = zlibReader r.Request.Body = zlibReader
}
// decode JSON
if strings.Contains(contentType, MIME_JSON) || MIME_JSON == defaultRequestContentType {
decoder := json.NewDecoder(entityReader)
decoder.UseNumber()
return decoder.Decode(entityPointer)
}
// decode XML
if strings.Contains(contentType, MIME_XML) || MIME_XML == defaultRequestContentType {
return xml.NewDecoder(entityReader).Decode(entityPointer)
} }
// lookup the EntityReader
entityReader, ok := entityAccessRegistry.accessorAt(contentType)
if !ok {
return NewError(http.StatusBadRequest, "Unable to unmarshal content of type:"+contentType) return NewError(http.StatusBadRequest, "Unable to unmarshal content of type:"+contentType)
}
return entityReader.Read(r, entityPointer)
} }
// SetAttribute adds or replaces the attribute with the given value. // SetAttribute adds or replaces the attribute with the given value.

View file

@ -5,18 +5,14 @@ package restful
// that can be found in the LICENSE file. // that can be found in the LICENSE file.
import ( import (
"encoding/json" "errors"
"encoding/xml"
"net/http" "net/http"
"strings"
) )
// DEPRECATED, use DefaultResponseContentType(mime) // DEPRECATED, use DefaultResponseContentType(mime)
var DefaultResponseMimeType string var DefaultResponseMimeType string
//PrettyPrintResponses controls the indentation feature of XML and JSON //PrettyPrintResponses controls the indentation feature of XML and JSON serialization
//serialization in the response methods WriteEntity, WriteAsJson, and
//WriteAsXml.
var PrettyPrintResponses = true var PrettyPrintResponses = true
// Response is a wrapper on the actual http ResponseWriter // Response is a wrapper on the actual http ResponseWriter
@ -36,8 +32,7 @@ func NewResponse(httpWriter http.ResponseWriter) *Response {
return &Response{httpWriter, "", []string{}, http.StatusOK, 0, PrettyPrintResponses, nil} // empty content-types return &Response{httpWriter, "", []string{}, http.StatusOK, 0, PrettyPrintResponses, nil} // empty content-types
} }
// If Accept header matching fails, fall back to this type, otherwise // If Accept header matching fails, fall back to this type.
// a "406: Not Acceptable" response is returned.
// Valid values are restful.MIME_JSON and restful.MIME_XML // Valid values are restful.MIME_JSON and restful.MIME_XML
// Example: // Example:
// restful.DefaultResponseContentType(restful.MIME_JSON) // restful.DefaultResponseContentType(restful.MIME_JSON)
@ -68,117 +63,100 @@ func (r *Response) SetRequestAccepts(mime string) {
r.requestAccept = mime r.requestAccept = mime
} }
// WriteEntity marshals the value using the representation denoted by the Accept Header (XML or JSON) // EntityWriter returns the registered EntityWriter that the entity (requested resource)
// If no Accept header is specified (or */*) then return the Content-Type as specified by the first in the Route.Produces. // can write according to what the request wants (Accept) and what the Route can produce or what the restful defaults say.
// If an Accept header is specified then return the Content-Type as specified by the first in the Route.Produces that is matched with the Accept header. // If called before WriteEntity and WriteHeader then a false return value can be used to write a 406: Not Acceptable.
// If the value is nil then nothing is written. You may want to call WriteHeader(http.StatusNotFound) instead. func (r *Response) EntityWriter() (EntityReaderWriter, bool) {
// Current implementation ignores any q-parameters in the Accept Header. sorted := sortedMimes(r.requestAccept)
func (r *Response) WriteEntity(value interface{}) error { for _, eachAccept := range sorted {
if value == nil { // do not write a nil representation for _, eachProduce := range r.routeProduces {
return nil if eachProduce == eachAccept.media {
if w, ok := entityAccessRegistry.accessorAt(eachAccept.media); ok {
return w, true
} }
for _, qualifiedMime := range strings.Split(r.requestAccept, ",") { }
mime := strings.Trim(strings.Split(qualifiedMime, ";")[0], " ") }
if 0 == len(mime) || mime == "*/*" { if eachAccept.media == "*/*" {
for _, each := range r.routeProduces { for _, each := range r.routeProduces {
if MIME_JSON == each { if w, ok := entityAccessRegistry.accessorAt(each); ok {
return r.WriteAsJson(value) return w, true
}
if MIME_XML == each {
return r.WriteAsXml(value)
}
}
} else { // mime is not blank; see if we have a match in Produces
for _, each := range r.routeProduces {
if mime == each {
if MIME_JSON == each {
return r.WriteAsJson(value)
}
if MIME_XML == each {
return r.WriteAsXml(value)
}
} }
} }
} }
} }
// if requestAccept is empty
writer, ok := entityAccessRegistry.accessorAt(r.requestAccept)
if !ok {
// if not registered then fallback to the defaults (if set)
if DefaultResponseMimeType == MIME_JSON { if DefaultResponseMimeType == MIME_JSON {
return r.WriteAsJson(value) return entityAccessRegistry.accessorAt(MIME_JSON)
} else if DefaultResponseMimeType == MIME_XML { }
return r.WriteAsXml(value) if DefaultResponseMimeType == MIME_XML {
} else { return entityAccessRegistry.accessorAt(MIME_XML)
}
// Fallback to whatever the route says it can produce.
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
for _, each := range r.routeProduces {
if w, ok := entityAccessRegistry.accessorAt(each); ok {
return w, true
}
}
if trace { if trace {
traceLogger.Printf("mismatch in mime-types and no defaults; (http)Accept=%v,(route)Produces=%v\n", r.requestAccept, r.routeProduces) traceLogger.Printf("no registered EntityReaderWriter found for %s", r.requestAccept)
}
r.WriteHeader(http.StatusNotAcceptable) // for recording only
r.ResponseWriter.WriteHeader(http.StatusNotAcceptable)
if _, err := r.Write([]byte("406: Not Acceptable")); err != nil {
return err
} }
} }
return writer, ok
}
// WriteEntity calls WriteHeaderAndEntity with Http Status OK (200)
func (r *Response) WriteEntity(value interface{}) error {
return r.WriteHeaderAndEntity(http.StatusOK, value)
}
// WriteHeaderAndEntity marshals the value using the representation denoted by the Accept Header and the registered EntityWriters.
// If no Accept header is specified (or */*) then respond with the Content-Type as specified by the first in the Route.Produces.
// If an Accept header is specified then respond with the Content-Type as specified by the first in the Route.Produces that is matched with the Accept header.
// If the value is nil then no response is send except for the Http status. You may want to call WriteHeader(http.StatusNotFound) instead.
// If there is no writer available that can represent the value in the requested MIME type then Http Status NotAcceptable is written.
// Current implementation ignores any q-parameters in the Accept Header.
// Returns an error if the value could not be written on the response.
func (r *Response) WriteHeaderAndEntity(status int, value interface{}) error {
writer, ok := r.EntityWriter()
if !ok {
r.WriteHeader(http.StatusNotAcceptable)
return nil return nil
}
return writer.Write(r, status, value)
} }
// WriteAsXml is a convenience method for writing a value in xml (requires Xml tags on the value) // WriteAsXml is a convenience method for writing a value in xml (requires Xml tags on the value)
// It uses the standard encoding/xml package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteAsXml(value interface{}) error { func (r *Response) WriteAsXml(value interface{}) error {
var output []byte return writeXML(r, http.StatusOK, MIME_XML, value)
var err error
if value == nil { // do not write a nil representation
return nil
}
if r.prettyPrint {
output, err = xml.MarshalIndent(value, " ", " ")
} else {
output, err = xml.Marshal(value)
}
if err != nil {
return r.WriteError(http.StatusInternalServerError, err)
}
r.Header().Set(HEADER_ContentType, MIME_XML)
if r.statusCode > 0 { // a WriteHeader was intercepted
r.ResponseWriter.WriteHeader(r.statusCode)
}
_, err = r.Write([]byte(xml.Header))
if err != nil {
return err
}
if _, err = r.Write(output); err != nil {
return err
}
return nil
} }
// WriteAsJson is a convenience method for writing a value in json // WriteHeaderAndXml is a convenience method for writing a status and value in xml (requires Xml tags on the value)
// It uses the standard encoding/xml package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteHeaderAndXml(status int, value interface{}) error {
return writeXML(r, status, MIME_XML, value)
}
// WriteAsJson is a convenience method for writing a value in json.
// It uses the standard encoding/json package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteAsJson(value interface{}) error { func (r *Response) WriteAsJson(value interface{}) error {
return r.WriteJson(value, MIME_JSON) // no charset return writeJSON(r, http.StatusOK, MIME_JSON, value)
} }
// WriteJson is a convenience method for writing a value in Json with a given Content-Type // WriteJson is a convenience method for writing a value in Json with a given Content-Type.
// It uses the standard encoding/json package for marshalling the value ; not using a registered EntityReaderWriter.
func (r *Response) WriteJson(value interface{}, contentType string) error { func (r *Response) WriteJson(value interface{}, contentType string) error {
var output []byte return writeJSON(r, http.StatusOK, contentType, value)
var err error }
if value == nil { // do not write a nil representation // WriteHeaderAndJson is a convenience method for writing the status and a value in Json with a given Content-Type.
return nil // It uses the standard encoding/json package for marshalling the value ; not using a registered EntityReaderWriter.
} func (r *Response) WriteHeaderAndJson(status int, value interface{}, contentType string) error {
if r.prettyPrint { return writeJSON(r, status, contentType, value)
output, err = json.MarshalIndent(value, " ", " ")
} else {
output, err = json.Marshal(value)
}
if err != nil {
return r.WriteErrorString(http.StatusInternalServerError, err.Error())
}
r.Header().Set(HEADER_ContentType, contentType)
if r.statusCode > 0 { // a WriteHeader was intercepted
r.ResponseWriter.WriteHeader(r.statusCode)
}
if _, err = r.Write(output); err != nil {
return err
}
return nil
} }
// WriteError write the http status and the error string on the response. // WriteError write the http status and the error string on the response.
@ -187,50 +165,42 @@ func (r *Response) WriteError(httpStatus int, err error) error {
return r.WriteErrorString(httpStatus, err.Error()) return r.WriteErrorString(httpStatus, err.Error())
} }
// WriteServiceError is a convenience method for a responding with a ServiceError and a status // WriteServiceError is a convenience method for a responding with a status and a ServiceError
func (r *Response) WriteServiceError(httpStatus int, err ServiceError) error { func (r *Response) WriteServiceError(httpStatus int, err ServiceError) error {
r.WriteHeader(httpStatus) // for recording only r.err = err
return r.WriteEntity(err) return r.WriteHeaderAndEntity(httpStatus, err)
} }
// WriteErrorString is a convenience method for an error status with the actual error // WriteErrorString is a convenience method for an error status with the actual error
func (r *Response) WriteErrorString(status int, errorReason string) error { func (r *Response) WriteErrorString(httpStatus int, errorReason string) error {
r.statusCode = status // for recording only if r.err == nil {
r.ResponseWriter.WriteHeader(status) // if not called from WriteError
r.err = errors.New(errorReason)
}
r.WriteHeader(httpStatus)
if _, err := r.Write([]byte(errorReason)); err != nil { if _, err := r.Write([]byte(errorReason)); err != nil {
return err return err
} }
return nil return nil
} }
// WriteHeader is overridden to remember the Status Code that has been written. // Flush implements http.Flusher interface, which sends any buffered data to the client.
// Note that using this method, the status value is only written when func (r *Response) Flush() {
// calling WriteEntity, if f, ok := r.ResponseWriter.(http.Flusher); ok {
// or directly calling WriteAsXml or WriteAsJson, f.Flush()
// or if the status is one for which no response is allowed: } else if trace {
// traceLogger.Printf("ResponseWriter %v doesn't support Flush", r)
// 202 = http.StatusAccepted
// 204 = http.StatusNoContent
// 206 = http.StatusPartialContent
// 304 = http.StatusNotModified
// 404 = http.StatusNotFound
//
// If this behavior does not fit your need then you can write to the underlying response, such as:
// response.ResponseWriter.WriteHeader(http.StatusAccepted)
func (r *Response) WriteHeader(httpStatus int) {
r.statusCode = httpStatus
// if 202,204,206,304,404 then WriteEntity will not be called so we need to pass this code
if http.StatusNotFound == httpStatus ||
http.StatusNoContent == httpStatus ||
http.StatusNotModified == httpStatus ||
http.StatusPartialContent == httpStatus ||
http.StatusAccepted == httpStatus {
r.ResponseWriter.WriteHeader(httpStatus)
} }
} }
// WriteHeader is overridden to remember the Status Code that has been written.
// Changes to the Header of the response have no effect after this.
func (r *Response) WriteHeader(httpStatus int) {
r.statusCode = httpStatus
r.ResponseWriter.WriteHeader(httpStatus)
}
// StatusCode returns the code that has been written using WriteHeader. // StatusCode returns the code that has been written using WriteHeader.
// If WriteHeader, WriteEntity or WriteAsXml has not been called (yet) then return 200 OK.
func (r Response) StatusCode() int { func (r Response) StatusCode() int {
if 0 == r.statusCode { if 0 == r.statusCode {
// no status code has been written yet; assume OK // no status code has been written yet; assume OK

View file

@ -128,7 +128,7 @@ func (b *RouteBuilder) Param(parameter *Parameter) *RouteBuilder {
return b return b
} }
// Operation allows you to document what the acutal method/function call is of the Route. // Operation allows you to document what the actual method/function call is of the Route.
// Unless called, the operation name is derived from the RouteFunction set using To(..). // Unless called, the operation name is derived from the RouteFunction set using To(..).
func (b *RouteBuilder) Operation(name string) *RouteBuilder { func (b *RouteBuilder) Operation(name string) *RouteBuilder {
b.operation = name b.operation = name

View file

@ -9,6 +9,8 @@ import (
// PostBuildDeclarationMapFunc can be used to modify the api declaration map. // PostBuildDeclarationMapFunc can be used to modify the api declaration map.
type PostBuildDeclarationMapFunc func(apiDeclarationMap *ApiDeclarationList) type PostBuildDeclarationMapFunc func(apiDeclarationMap *ApiDeclarationList)
type MapSchemaFormatFunc func(typeName string) string
type Config struct { type Config struct {
// url where the services are available, e.g. http://localhost:8080 // url where the services are available, e.g. http://localhost:8080
// if left empty then the basePath of Swagger is taken from the actual request // if left empty then the basePath of Swagger is taken from the actual request
@ -29,4 +31,8 @@ type Config struct {
ApiVersion string ApiVersion string
// If set then call this handler after building the complete ApiDeclaration Map // If set then call this handler after building the complete ApiDeclaration Map
PostBuildHandler PostBuildDeclarationMapFunc PostBuildHandler PostBuildDeclarationMapFunc
// Swagger global info struct
Info Info
// [optional] If set, model builder should call this handler to get addition typename-to-swagger-format-field convertion.
SchemaFormatHandler MapSchemaFormatFunc
} }

View file

@ -14,6 +14,7 @@ type ModelBuildable interface {
type modelBuilder struct { type modelBuilder struct {
Models *ModelList Models *ModelList
Config *Config
} }
type documentable interface { type documentable interface {
@ -50,6 +51,14 @@ func (b modelBuilder) addModel(st reflect.Type, nameOverride string) *Model {
if b.isPrimitiveType(modelName) { if b.isPrimitiveType(modelName) {
return nil return nil
} }
// golang encoding/json packages says array and slice values encode as
// JSON arrays, except that []byte encodes as a base64-encoded string.
// If we see a []byte here, treat it at as a primitive type (string)
// and deal with it in buildArrayTypeProperty.
if (st.Kind() == reflect.Slice || st.Kind() == reflect.Array) &&
st.Elem().Kind() == reflect.Uint8 {
return nil
}
// see if we already have visited this model // see if we already have visited this model
if _, ok := b.Models.At(modelName); ok { if _, ok := b.Models.At(modelName); ok {
return nil return nil
@ -132,9 +141,11 @@ func (b modelBuilder) buildProperty(field reflect.StructField, model *Model, mod
modelDescription = tag modelDescription = tag
} }
fieldType := field.Type
prop.setPropertyMetadata(field) prop.setPropertyMetadata(field)
if prop.Type != nil {
return jsonName, modelDescription, prop
}
fieldType := field.Type
// check if type is doing its own marshalling // check if type is doing its own marshalling
marshalerType := reflect.TypeOf((*json.Marshaler)(nil)).Elem() marshalerType := reflect.TypeOf((*json.Marshaler)(nil)).Elem()
@ -176,8 +187,8 @@ func (b modelBuilder) buildProperty(field reflect.StructField, model *Model, mod
return jsonName, modelDescription, prop return jsonName, modelDescription, prop
case fieldKind == reflect.Map: case fieldKind == reflect.Map:
// if it's a map, it's unstructured, and swagger 1.2 can't handle it // if it's a map, it's unstructured, and swagger 1.2 can't handle it
anyt := "any" objectType := "object"
prop.Type = &anyt prop.Type = &objectType
return jsonName, modelDescription, prop return jsonName, modelDescription, prop
} }
@ -212,8 +223,12 @@ func hasNamedJSONTag(field reflect.StructField) bool {
} }
func (b modelBuilder) buildStructTypeProperty(field reflect.StructField, jsonName string, model *Model) (nameJson string, prop ModelProperty) { func (b modelBuilder) buildStructTypeProperty(field reflect.StructField, jsonName string, model *Model) (nameJson string, prop ModelProperty) {
fieldType := field.Type
prop.setPropertyMetadata(field) prop.setPropertyMetadata(field)
// Check for type override in tag
if prop.Type != nil {
return jsonName, prop
}
fieldType := field.Type
// check for anonymous // check for anonymous
if len(fieldType.Name()) == 0 { if len(fieldType.Name()) == 0 {
// anonymous // anonymous
@ -225,7 +240,7 @@ func (b modelBuilder) buildStructTypeProperty(field reflect.StructField, jsonNam
if field.Name == fieldType.Name() && field.Anonymous && !hasNamedJSONTag(field) { if field.Name == fieldType.Name() && field.Anonymous && !hasNamedJSONTag(field) {
// embedded struct // embedded struct
sub := modelBuilder{new(ModelList)} sub := modelBuilder{new(ModelList), b.Config}
sub.addModel(fieldType, "") sub.addModel(fieldType, "")
subKey := sub.keyFrom(fieldType) subKey := sub.keyFrom(fieldType)
// merge properties from sub // merge properties from sub
@ -263,13 +278,23 @@ func (b modelBuilder) buildStructTypeProperty(field reflect.StructField, jsonNam
} }
func (b modelBuilder) buildArrayTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop ModelProperty) { func (b modelBuilder) buildArrayTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop ModelProperty) {
fieldType := field.Type // check for type override in tags
prop.setPropertyMetadata(field) prop.setPropertyMetadata(field)
if prop.Type != nil {
return jsonName, prop
}
fieldType := field.Type
if fieldType.Elem().Kind() == reflect.Uint8 {
stringt := "string"
prop.Type = &stringt
return jsonName, prop
}
var pType = "array" var pType = "array"
prop.Type = &pType prop.Type = &pType
isPrimitive := b.isPrimitiveType(fieldType.Elem().Name())
elemTypeName := b.getElementTypeName(modelName, jsonName, fieldType.Elem()) elemTypeName := b.getElementTypeName(modelName, jsonName, fieldType.Elem())
prop.Items = new(Item) prop.Items = new(Item)
if b.isPrimitiveType(elemTypeName) { if isPrimitive {
mapped := b.jsonSchemaType(elemTypeName) mapped := b.jsonSchemaType(elemTypeName)
prop.Items.Type = &mapped prop.Items.Type = &mapped
} else { } else {
@ -279,22 +304,36 @@ func (b modelBuilder) buildArrayTypeProperty(field reflect.StructField, jsonName
if fieldType.Elem().Kind() == reflect.Ptr { if fieldType.Elem().Kind() == reflect.Ptr {
fieldType = fieldType.Elem() fieldType = fieldType.Elem()
} }
if !isPrimitive {
b.addModel(fieldType.Elem(), elemTypeName) b.addModel(fieldType.Elem(), elemTypeName)
}
return jsonName, prop return jsonName, prop
} }
func (b modelBuilder) buildPointerTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop ModelProperty) { func (b modelBuilder) buildPointerTypeProperty(field reflect.StructField, jsonName, modelName string) (nameJson string, prop ModelProperty) {
fieldType := field.Type
prop.setPropertyMetadata(field) prop.setPropertyMetadata(field)
// Check for type override in tags
if prop.Type != nil {
return jsonName, prop
}
fieldType := field.Type
// override type of pointer to list-likes // override type of pointer to list-likes
if fieldType.Elem().Kind() == reflect.Slice || fieldType.Elem().Kind() == reflect.Array { if fieldType.Elem().Kind() == reflect.Slice || fieldType.Elem().Kind() == reflect.Array {
var pType = "array" var pType = "array"
prop.Type = &pType prop.Type = &pType
isPrimitive := b.isPrimitiveType(fieldType.Elem().Elem().Name())
elemName := b.getElementTypeName(modelName, jsonName, fieldType.Elem().Elem()) elemName := b.getElementTypeName(modelName, jsonName, fieldType.Elem().Elem())
if isPrimitive {
primName := b.jsonSchemaType(elemName)
prop.Items = &Item{Ref: &primName}
} else {
prop.Items = &Item{Ref: &elemName} prop.Items = &Item{Ref: &elemName}
}
if !isPrimitive {
// add|overwrite model for element type // add|overwrite model for element type
b.addModel(fieldType.Elem().Elem(), elemName) b.addModel(fieldType.Elem().Elem(), elemName)
}
} else { } else {
// non-array, pointer type // non-array, pointer type
var pType = b.jsonSchemaType(fieldType.String()[1:]) // no star, include pkg path var pType = b.jsonSchemaType(fieldType.String()[1:]) // no star, include pkg path
@ -321,9 +360,6 @@ func (b modelBuilder) getElementTypeName(modelName, jsonName string, t reflect.T
if t.Name() == "" { if t.Name() == "" {
return modelName + "." + jsonName return modelName + "." + jsonName
} }
if b.isPrimitiveType(t.Name()) {
return b.jsonSchemaType(t.Name())
}
return b.keyFrom(t) return b.keyFrom(t)
} }
@ -338,7 +374,10 @@ func (b modelBuilder) keyFrom(st reflect.Type) string {
// see also https://golang.org/ref/spec#Numeric_types // see also https://golang.org/ref/spec#Numeric_types
func (b modelBuilder) isPrimitiveType(modelName string) bool { func (b modelBuilder) isPrimitiveType(modelName string) bool {
return strings.Contains("uint8 uint16 uint32 uint64 int int8 int16 int32 int64 float32 float64 bool string byte rune time.Time", modelName) if len(modelName) == 0 {
return false
}
return strings.Contains("uint uint8 uint16 uint32 uint64 int int8 int16 int32 int64 float32 float64 bool string byte rune time.Time", modelName)
} }
// jsonNameOfField returns the name of the field as it should appear in JSON format // jsonNameOfField returns the name of the field as it should appear in JSON format
@ -359,6 +398,7 @@ func (b modelBuilder) jsonNameOfField(field reflect.StructField) string {
// see also http://json-schema.org/latest/json-schema-core.html#anchor8 // see also http://json-schema.org/latest/json-schema-core.html#anchor8
func (b modelBuilder) jsonSchemaType(modelName string) string { func (b modelBuilder) jsonSchemaType(modelName string) string {
schemaMap := map[string]string{ schemaMap := map[string]string{
"uint": "integer",
"uint8": "integer", "uint8": "integer",
"uint16": "integer", "uint16": "integer",
"uint32": "integer", "uint32": "integer",
@ -384,11 +424,17 @@ func (b modelBuilder) jsonSchemaType(modelName string) string {
} }
func (b modelBuilder) jsonSchemaFormat(modelName string) string { func (b modelBuilder) jsonSchemaFormat(modelName string) string {
if b.Config != nil && b.Config.SchemaFormatHandler != nil {
if mapped := b.Config.SchemaFormatHandler(modelName); mapped != "" {
return mapped
}
}
schemaMap := map[string]string{ schemaMap := map[string]string{
"int": "int32", "int": "int32",
"int32": "int32", "int32": "int32",
"int64": "int64", "int64": "int64",
"byte": "byte", "byte": "byte",
"uint": "integer",
"uint8": "byte", "uint8": "byte",
"float64": "double", "float64": "double",
"float32": "float", "float32": "float",

View file

@ -31,6 +31,12 @@ func (prop *ModelProperty) setMaximum(field reflect.StructField) {
} }
} }
func (prop *ModelProperty) setType(field reflect.StructField) {
if tag := field.Tag.Get("type"); tag != "" {
prop.Type = &tag
}
}
func (prop *ModelProperty) setMinimum(field reflect.StructField) { func (prop *ModelProperty) setMinimum(field reflect.StructField) {
if tag := field.Tag.Get("minimum"); tag != "" { if tag := field.Tag.Get("minimum"); tag != "" {
prop.Minimum = tag prop.Minimum = tag
@ -56,4 +62,5 @@ func (prop *ModelProperty) setPropertyMetadata(field reflect.StructField) {
prop.setMaximum(field) prop.setMaximum(field)
prop.setUniqueItems(field) prop.setUniqueItems(field)
prop.setDefaultValue(field) prop.setDefaultValue(field)
prop.setType(field)
} }

View file

@ -48,7 +48,7 @@ type Info struct {
TermsOfServiceUrl string `json:"termsOfServiceUrl,omitempty"` TermsOfServiceUrl string `json:"termsOfServiceUrl,omitempty"`
Contact string `json:"contact,omitempty"` Contact string `json:"contact,omitempty"`
License string `json:"license,omitempty"` License string `json:"license,omitempty"`
LicensUrl string `json:"licensUrl,omitempty"` LicenseUrl string `json:"licenseUrl,omitempty"`
} }
// 5.1.5 // 5.1.5
@ -118,6 +118,7 @@ type ApiDeclaration struct {
ApiVersion string `json:"apiVersion"` ApiVersion string `json:"apiVersion"`
BasePath string `json:"basePath"` BasePath string `json:"basePath"`
ResourcePath string `json:"resourcePath"` // must start with / ResourcePath string `json:"resourcePath"` // must start with /
Info Info `json:"info"`
Apis []Api `json:"apis,omitempty"` Apis []Api `json:"apis,omitempty"`
Models ModelList `json:"models,omitempty"` Models ModelList `json:"models,omitempty"`
Produces []string `json:"produces,omitempty"` Produces []string `json:"produces,omitempty"`
@ -134,7 +135,7 @@ type Api struct {
// 5.2.3 Operation Object // 5.2.3 Operation Object
type Operation struct { type Operation struct {
Type string `json:"type"` DataTypeFields
Method string `json:"method"` Method string `json:"method"`
Summary string `json:"summary,omitempty"` Summary string `json:"summary,omitempty"`
Notes string `json:"notes,omitempty"` Notes string `json:"notes,omitempty"`

View file

@ -0,0 +1,21 @@
package swagger
type SwaggerBuilder struct {
SwaggerService
}
func NewSwaggerBuilder(config Config) *SwaggerBuilder {
return &SwaggerBuilder{*newSwaggerService(config)}
}
func (sb SwaggerBuilder) ProduceListing() ResourceListing {
return sb.SwaggerService.produceListing()
}
func (sb SwaggerBuilder) ProduceAllDeclarations() map[string]ApiDeclaration {
return sb.SwaggerService.produceAllDeclarations()
}
func (sb SwaggerBuilder) ProduceDeclarations(route string) (*ApiDeclaration, bool) {
return sb.SwaggerService.produceDeclarations(route)
}

View file

@ -19,9 +19,35 @@ type SwaggerService struct {
} }
func newSwaggerService(config Config) *SwaggerService { func newSwaggerService(config Config) *SwaggerService {
return &SwaggerService{ sws := &SwaggerService{
config: config, config: config,
apiDeclarationMap: new(ApiDeclarationList)} apiDeclarationMap: new(ApiDeclarationList)}
// Build all ApiDeclarations
for _, each := range config.WebServices {
rootPath := each.RootPath()
// skip the api service itself
if rootPath != config.ApiPath {
if rootPath == "" || rootPath == "/" {
// use routes
for _, route := range each.Routes() {
entry := staticPathFromRoute(route)
_, exists := sws.apiDeclarationMap.At(entry)
if !exists {
sws.apiDeclarationMap.Put(entry, sws.composeDeclaration(each, entry))
}
}
} else { // use root path
sws.apiDeclarationMap.Put(each.RootPath(), sws.composeDeclaration(each, each.RootPath()))
}
}
}
// if specified then call the PostBuilderHandler
if config.PostBuildHandler != nil {
config.PostBuildHandler(sws.apiDeclarationMap)
}
return sws
} }
// LogInfo is the function that is called when this package needs to log. It defaults to log.Printf // LogInfo is the function that is called when this package needs to log. It defaults to log.Printf
@ -57,31 +83,6 @@ func RegisterSwaggerService(config Config, wsContainer *restful.Container) {
LogInfo("[restful/swagger] listing is available at %v%v", config.WebServicesUrl, config.ApiPath) LogInfo("[restful/swagger] listing is available at %v%v", config.WebServicesUrl, config.ApiPath)
wsContainer.Add(ws) wsContainer.Add(ws)
// Build all ApiDeclarations
for _, each := range config.WebServices {
rootPath := each.RootPath()
// skip the api service itself
if rootPath != config.ApiPath {
if rootPath == "" || rootPath == "/" {
// use routes
for _, route := range each.Routes() {
entry := staticPathFromRoute(route)
_, exists := sws.apiDeclarationMap.At(entry)
if !exists {
sws.apiDeclarationMap.Put(entry, sws.composeDeclaration(each, entry))
}
}
} else { // use root path
sws.apiDeclarationMap.Put(each.RootPath(), sws.composeDeclaration(each, each.RootPath()))
}
}
}
// if specified then call the PostBuilderHandler
if config.PostBuildHandler != nil {
config.PostBuildHandler(sws.apiDeclarationMap)
}
// Check paths for UI serving // Check paths for UI serving
if config.StaticHandler == nil && config.SwaggerFilePath != "" && config.SwaggerPath != "" { if config.StaticHandler == nil && config.SwaggerFilePath != "" && config.SwaggerPath != "" {
swaggerPathSlash := config.SwaggerPath swaggerPathSlash := config.SwaggerPath
@ -138,7 +139,12 @@ func enableCORS(req *restful.Request, resp *restful.Response, chain *restful.Fil
} }
func (sws SwaggerService) getListing(req *restful.Request, resp *restful.Response) { func (sws SwaggerService) getListing(req *restful.Request, resp *restful.Response) {
listing := ResourceListing{SwaggerVersion: swaggerVersion, ApiVersion: sws.config.ApiVersion} listing := sws.produceListing()
resp.WriteAsJson(listing)
}
func (sws SwaggerService) produceListing() ResourceListing {
listing := ResourceListing{SwaggerVersion: swaggerVersion, ApiVersion: sws.config.ApiVersion, Info: sws.config.Info}
sws.apiDeclarationMap.Do(func(k string, v ApiDeclaration) { sws.apiDeclarationMap.Do(func(k string, v ApiDeclaration) {
ref := Resource{Path: k} ref := Resource{Path: k}
if len(v.Apis) > 0 { // use description of first (could still be empty) if len(v.Apis) > 0 { // use description of first (could still be empty)
@ -146,11 +152,11 @@ func (sws SwaggerService) getListing(req *restful.Request, resp *restful.Respons
} }
listing.Apis = append(listing.Apis, ref) listing.Apis = append(listing.Apis, ref)
}) })
resp.WriteAsJson(listing) return listing
} }
func (sws SwaggerService) getDeclarations(req *restful.Request, resp *restful.Response) { func (sws SwaggerService) getDeclarations(req *restful.Request, resp *restful.Response) {
decl, ok := sws.apiDeclarationMap.At(composeRootPath(req)) decl, ok := sws.produceDeclarations(composeRootPath(req))
if !ok { if !ok {
resp.WriteErrorString(http.StatusNotFound, "ApiDeclaration not found") resp.WriteErrorString(http.StatusNotFound, "ApiDeclaration not found")
return return
@ -180,11 +186,28 @@ func (sws SwaggerService) getDeclarations(req *restful.Request, resp *restful.Re
scheme = "https" scheme = "https"
} }
} }
(&decl).BasePath = fmt.Sprintf("%s://%s", scheme, host) decl.BasePath = fmt.Sprintf("%s://%s", scheme, host)
} }
resp.WriteAsJson(decl) resp.WriteAsJson(decl)
} }
func (sws SwaggerService) produceAllDeclarations() map[string]ApiDeclaration {
decls := map[string]ApiDeclaration{}
sws.apiDeclarationMap.Do(func(k string, v ApiDeclaration) {
decls[k] = v
})
return decls
}
func (sws SwaggerService) produceDeclarations(route string) (*ApiDeclaration, bool) {
decl, ok := sws.apiDeclarationMap.At(route)
if !ok {
return nil, false
}
decl.BasePath = sws.config.WebServicesUrl
return &decl, true
}
// composeDeclaration uses all routes and parameters to create a ApiDeclaration // composeDeclaration uses all routes and parameters to create a ApiDeclaration
func (sws SwaggerService) composeDeclaration(ws *restful.WebService, pathPrefix string) ApiDeclaration { func (sws SwaggerService) composeDeclaration(ws *restful.WebService, pathPrefix string) ApiDeclaration {
decl := ApiDeclaration{ decl := ApiDeclaration{
@ -207,16 +230,18 @@ func (sws SwaggerService) composeDeclaration(ws *restful.WebService, pathPrefix
} }
} }
pathToRoutes.Do(func(path string, routes []restful.Route) { pathToRoutes.Do(func(path string, routes []restful.Route) {
api := Api{Path: strings.TrimSuffix(path, "/"), Description: ws.Documentation()} api := Api{Path: strings.TrimSuffix(withoutWildcard(path), "/"), Description: ws.Documentation()}
voidString := "void"
for _, route := range routes { for _, route := range routes {
operation := Operation{ operation := Operation{
Method: route.Method, Method: route.Method,
Summary: route.Doc, Summary: route.Doc,
Notes: route.Notes, Notes: route.Notes,
Type: asDataType(route.WriteSample), // Type gets overwritten if there is a write sample
DataTypeFields: DataTypeFields{Type: &voidString},
Parameters: []Parameter{}, Parameters: []Parameter{},
Nickname: route.Operation, Nickname: route.Operation,
ResponseMessages: composeResponseMessages(route, &decl)} ResponseMessages: composeResponseMessages(route, &decl, &sws.config)}
operation.Consumes = route.Consumes operation.Consumes = route.Consumes
operation.Produces = route.Produces operation.Produces = route.Produces
@ -238,8 +263,15 @@ func (sws SwaggerService) composeDeclaration(ws *restful.WebService, pathPrefix
return decl return decl
} }
func withoutWildcard(path string) string {
if strings.HasSuffix(path, ":*}") {
return path[0:len(path)-3] + "}"
}
return path
}
// composeResponseMessages takes the ResponseErrors (if any) and creates ResponseMessages from them. // composeResponseMessages takes the ResponseErrors (if any) and creates ResponseMessages from them.
func composeResponseMessages(route restful.Route, decl *ApiDeclaration) (messages []ResponseMessage) { func composeResponseMessages(route restful.Route, decl *ApiDeclaration, config *Config) (messages []ResponseMessage) {
if route.ResponseErrors == nil { if route.ResponseErrors == nil {
return messages return messages
} }
@ -262,7 +294,7 @@ func composeResponseMessages(route restful.Route, decl *ApiDeclaration) (message
if isCollection { if isCollection {
modelName = "array[" + modelName + "]" modelName = "array[" + modelName + "]"
} }
modelBuilder{&decl.Models}.addModel(st, "") modelBuilder{Models: &decl.Models, Config: config}.addModel(st, "")
// reference the model // reference the model
message.ResponseModel = modelName message.ResponseModel = modelName
} }
@ -299,23 +331,19 @@ func detectCollectionType(st reflect.Type) (bool, reflect.Type) {
// addModelFromSample creates and adds (or overwrites) a Model from a sample resource // addModelFromSample creates and adds (or overwrites) a Model from a sample resource
func (sws SwaggerService) addModelFromSampleTo(operation *Operation, isResponse bool, sample interface{}, models *ModelList) { func (sws SwaggerService) addModelFromSampleTo(operation *Operation, isResponse bool, sample interface{}, models *ModelList) {
st := reflect.TypeOf(sample)
isCollection, st := detectCollectionType(st)
modelName := modelBuilder{}.keyFrom(st)
if isResponse { if isResponse {
if isCollection { type_, items := asDataType(sample, &sws.config)
modelName = "array[" + modelName + "]" operation.Type = type_
operation.Items = items
} }
operation.Type = modelName modelBuilder{Models: models, Config: &sws.config}.addModelFrom(sample)
}
modelBuilder{models}.addModelFrom(sample)
} }
func asSwaggerParameter(param restful.ParameterData) Parameter { func asSwaggerParameter(param restful.ParameterData) Parameter {
return Parameter{ return Parameter{
DataTypeFields: DataTypeFields{ DataTypeFields: DataTypeFields{
Type: &param.DataType, Type: &param.DataType,
Format: asFormat(param.DataType), Format: asFormat(param.DataType, param.DataFormat),
DefaultValue: Special(param.DefaultValue), DefaultValue: Special(param.DefaultValue),
}, },
Name: param.Name, Name: param.Name,
@ -360,7 +388,10 @@ func composeRootPath(req *restful.Request) string {
return path + "/" + g return path + "/" + g
} }
func asFormat(name string) string { func asFormat(dataType string, dataFormat string) string {
if dataFormat != "" {
return dataFormat
}
return "" // TODO return "" // TODO
} }
@ -380,9 +411,30 @@ func asParamType(kind int) string {
return "" return ""
} }
func asDataType(any interface{}) string { func asDataType(any interface{}, config *Config) (*string, *Item) {
if any == nil { // If it's not a collection, return the suggested model name
return "void" st := reflect.TypeOf(any)
isCollection, st := detectCollectionType(st)
modelName := modelBuilder{}.keyFrom(st)
// if it's not a collection we are done
if !isCollection {
return &modelName, nil
} }
return reflect.TypeOf(any).Name()
// XXX: This is not very elegant
// We create an Item object referring to the given model
models := ModelList{}
mb := modelBuilder{Models: &models, Config: config}
mb.addModelFrom(any)
elemTypeName := mb.getElementTypeName(modelName, "", st)
item := new(Item)
if mb.isPrimitiveType(elemTypeName) {
mapped := mb.jsonSchemaType(elemTypeName)
item.Type = &mapped
} else {
item.Ref = &elemTypeName
}
tmp := "array"
return &tmp, item
} }

View file

@ -1,7 +1,7 @@
package restful package restful
import ( import (
"fmt" "errors"
"os" "os"
"sync" "sync"
@ -36,9 +36,6 @@ func (w *WebService) SetDynamicRoutes(enable bool) {
// compilePathExpression ensures that the path is compiled into a RegEx for those routers that need it. // compilePathExpression ensures that the path is compiled into a RegEx for those routers that need it.
func (w *WebService) compilePathExpression() { func (w *WebService) compilePathExpression() {
if len(w.rootPath) == 0 {
w.Path("/") // lazy initialize path
}
compiled, err := newPathExpression(w.rootPath) compiled, err := newPathExpression(w.rootPath)
if err != nil { if err != nil {
log.Printf("[restful] invalid path:%s because:%v", w.rootPath, err) log.Printf("[restful] invalid path:%s because:%v", w.rootPath, err)
@ -54,12 +51,15 @@ func (w *WebService) ApiVersion(apiVersion string) *WebService {
} }
// Version returns the API version for documentation purposes. // Version returns the API version for documentation purposes.
func (w WebService) Version() string { return w.apiVersion } func (w *WebService) Version() string { return w.apiVersion }
// Path specifies the root URL template path of the WebService. // Path specifies the root URL template path of the WebService.
// All Routes will be relative to this path. // All Routes will be relative to this path.
func (w *WebService) Path(root string) *WebService { func (w *WebService) Path(root string) *WebService {
w.rootPath = root w.rootPath = root
if len(w.rootPath) == 0 {
w.rootPath = "/"
}
w.compilePathExpression() w.compilePathExpression()
return w return w
} }
@ -155,15 +155,20 @@ func (w *WebService) Route(builder *RouteBuilder) *WebService {
// RemoveRoute removes the specified route, looks for something that matches 'path' and 'method' // RemoveRoute removes the specified route, looks for something that matches 'path' and 'method'
func (w *WebService) RemoveRoute(path, method string) error { func (w *WebService) RemoveRoute(path, method string) error {
if !w.dynamicRoutes { if !w.dynamicRoutes {
return fmt.Errorf("dynamic routes are not enabled.") return errors.New("dynamic routes are not enabled.")
} }
w.routesLock.Lock() w.routesLock.Lock()
defer w.routesLock.Unlock() defer w.routesLock.Unlock()
newRoutes := make([]Route, (len(w.routes) - 1))
current := 0
for ix := range w.routes { for ix := range w.routes {
if w.routes[ix].Method == method && w.routes[ix].Path == path { if w.routes[ix].Method == method && w.routes[ix].Path == path {
w.routes = append(w.routes[:ix], w.routes[ix+1:]...) continue
} }
newRoutes[current] = w.routes[ix]
current = current + 1
} }
w.routes = newRoutes
return nil return nil
} }
@ -187,7 +192,7 @@ func (w *WebService) Consumes(accepts ...string) *WebService {
} }
// Routes returns the Routes associated with this WebService // Routes returns the Routes associated with this WebService
func (w WebService) Routes() []Route { func (w *WebService) Routes() []Route {
if !w.dynamicRoutes { if !w.dynamicRoutes {
return w.routes return w.routes
} }
@ -202,12 +207,12 @@ func (w WebService) Routes() []Route {
} }
// RootPath returns the RootPath associated with this WebService. Default "/" // RootPath returns the RootPath associated with this WebService. Default "/"
func (w WebService) RootPath() string { func (w *WebService) RootPath() string {
return w.rootPath return w.rootPath
} }
// PathParameters return the path parameter names for (shared amoung its Routes) // PathParameters return the path parameter names for (shared amoung its Routes)
func (w WebService) PathParameters() []*Parameter { func (w *WebService) PathParameters() []*Parameter {
return w.pathParameters return w.pathParameters
} }
@ -224,7 +229,7 @@ func (w *WebService) Doc(plainText string) *WebService {
} }
// Documentation returns it. // Documentation returns it.
func (w WebService) Documentation() string { func (w *WebService) Documentation() string {
return w.documentation return w.documentation
} }