551 lines
16 KiB
Go
551 lines
16 KiB
Go
// Package httpcache provides an http.RoundTripper implementation that works as a
// mostly RFC-compliant cache for HTTP responses.
|
|
//
|
|
// It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client
|
|
// and not for a shared proxy).
|
|
//
|
|
package httpcache
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Freshness states reported by getFreshness.
const (
	stale       = iota // the cached response must be revalidated before use
	fresh              // the cached response may be returned as-is
	transparent        // the cached response must not be used for this request
)

// XFromCache is the header added to responses that are returned from the cache
const XFromCache = "X-From-Cache"
|
|
|
|
// A Cache interface is used by the Transport to store and retrieve responses.
//
// Entries are opaque byte slices indexed by a string key; the Transport stores
// serialized responses in them and parses them back when reading.
type Cache interface {
	// Set records responseBytes under key.
	Set(key string, responseBytes []byte)
	// Get fetches the bytes stored under key; ok reports whether an entry was found.
	Get(key string) (responseBytes []byte, ok bool)
	// Delete discards whatever is stored under key.
	Delete(key string)
}
|
|
|
|
// cacheKey derives the cache key for req. A GET is keyed by its URL alone;
// any other method is keyed as "<METHOD> <URL>" so its entry is kept apart
// from the GET entry for the same URL.
func cacheKey(req *http.Request) string {
	switch u := req.URL.String(); req.Method {
	case http.MethodGet:
		return u
	default:
		return req.Method + " " + u
	}
}
|
|
|
|
// CachedResponse returns the cached http.Response for req if present, and nil
|
|
// otherwise.
|
|
func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
|
|
cachedVal, ok := c.Get(cacheKey(req))
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
b := bytes.NewBuffer(cachedVal)
|
|
return http.ReadResponse(bufio.NewReader(b), req)
|
|
}
|
|
|
|
// MemoryCache is an implementation of Cache that stores responses in an
// in-memory map. All access goes through an internal RWMutex, so it is safe
// for concurrent use. The zero value is an empty cache ready for use;
// NewMemoryCache is provided for convenience.
type MemoryCache struct {
	mu    sync.RWMutex
	items map[string][]byte
}

// Get returns the []byte representation of the response stored under key and
// true if present, or nil and false if not.
func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	// Reading from a nil map is fine, so the zero value needs no special case here.
	resp, ok = c.items[key]
	return resp, ok
}

// Set saves the response bytes resp in the cache under key, replacing any
// existing entry.
func (c *MemoryCache) Set(key string, resp []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.items == nil {
		// Allocate lazily so a zero-value MemoryCache does not panic on its first write.
		c.items = make(map[string][]byte)
	}
	c.items[key] = resp
}

// Delete removes key from the cache; it is a no-op if key is absent.
func (c *MemoryCache) Delete(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.items, key)
}

// NewMemoryCache returns a new, empty MemoryCache.
func NewMemoryCache() *MemoryCache {
	return &MemoryCache{items: make(map[string][]byte)}
}
|
|
|
|
// Transport is an implementation of http.RoundTripper that will return values from a cache
|
|
// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since)
|
|
// to repeated requests allowing servers to return 304 / Not Modified
|
|
type Transport struct {
|
|
// The RoundTripper interface actually used to make requests
|
|
// If nil, http.DefaultTransport is used
|
|
Transport http.RoundTripper
|
|
Cache Cache
|
|
// If true, responses returned from the cache will be given an extra header, X-From-Cache
|
|
MarkCachedResponses bool
|
|
}
|
|
|
|
// NewTransport returns a new Transport with the
|
|
// provided Cache implementation and MarkCachedResponses set to true
|
|
func NewTransport(c Cache) *Transport {
|
|
return &Transport{Cache: c, MarkCachedResponses: true}
|
|
}
|
|
|
|
// Client returns an *http.Client that caches responses.
|
|
func (t *Transport) Client() *http.Client {
|
|
return &http.Client{Transport: t}
|
|
}
|
|
|
|
// varyMatches will return false unless all of the cached values for the headers listed in Vary
|
|
// match the new request
|
|
func varyMatches(cachedResp *http.Response, req *http.Request) bool {
|
|
for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
|
|
header = http.CanonicalHeaderKey(header)
|
|
if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// RoundTrip takes a Request and returns a Response
|
|
//
|
|
// If there is a fresh Response already in cache, then it will be returned without connecting to
|
|
// the server.
|
|
//
|
|
// If there is a stale Response, then any validators it contains will be set on the new request
|
|
// to give the server a chance to respond with NotModified. If this happens, then the cached Response
|
|
// will be returned.
|
|
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
|
cacheKey := cacheKey(req)
|
|
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
|
|
var cachedResp *http.Response
|
|
if cacheable {
|
|
cachedResp, err = CachedResponse(t.Cache, req)
|
|
} else {
|
|
// Need to invalidate an existing value
|
|
t.Cache.Delete(cacheKey)
|
|
}
|
|
|
|
transport := t.Transport
|
|
if transport == nil {
|
|
transport = http.DefaultTransport
|
|
}
|
|
|
|
if cacheable && cachedResp != nil && err == nil {
|
|
if t.MarkCachedResponses {
|
|
cachedResp.Header.Set(XFromCache, "1")
|
|
}
|
|
|
|
if varyMatches(cachedResp, req) {
|
|
// Can only use cached value if the new request doesn't Vary significantly
|
|
freshness := getFreshness(cachedResp.Header, req.Header)
|
|
if freshness == fresh {
|
|
return cachedResp, nil
|
|
}
|
|
|
|
if freshness == stale {
|
|
var req2 *http.Request
|
|
// Add validators if caller hasn't already done so
|
|
etag := cachedResp.Header.Get("etag")
|
|
if etag != "" && req.Header.Get("etag") == "" {
|
|
req2 = cloneRequest(req)
|
|
req2.Header.Set("if-none-match", etag)
|
|
}
|
|
lastModified := cachedResp.Header.Get("last-modified")
|
|
if lastModified != "" && req.Header.Get("last-modified") == "" {
|
|
if req2 == nil {
|
|
req2 = cloneRequest(req)
|
|
}
|
|
req2.Header.Set("if-modified-since", lastModified)
|
|
}
|
|
if req2 != nil {
|
|
req = req2
|
|
}
|
|
}
|
|
}
|
|
|
|
resp, err = transport.RoundTrip(req)
|
|
if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
|
|
// Replace the 304 response with the one from cache, but update with some new headers
|
|
endToEndHeaders := getEndToEndHeaders(resp.Header)
|
|
for _, header := range endToEndHeaders {
|
|
cachedResp.Header[header] = resp.Header[header]
|
|
}
|
|
resp = cachedResp
|
|
} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
|
|
req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
|
|
// In case of transport failure and stale-if-error activated, returns cached content
|
|
// when available
|
|
return cachedResp, nil
|
|
} else {
|
|
if err != nil || resp.StatusCode != http.StatusOK {
|
|
t.Cache.Delete(cacheKey)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
} else {
|
|
reqCacheControl := parseCacheControl(req.Header)
|
|
if _, ok := reqCacheControl["only-if-cached"]; ok {
|
|
resp = newGatewayTimeoutResponse(req)
|
|
} else {
|
|
resp, err = transport.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
|
|
for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
|
|
varyKey = http.CanonicalHeaderKey(varyKey)
|
|
fakeHeader := "X-Varied-" + varyKey
|
|
reqValue := req.Header.Get(varyKey)
|
|
if reqValue != "" {
|
|
resp.Header.Set(fakeHeader, reqValue)
|
|
}
|
|
}
|
|
switch req.Method {
|
|
case "GET":
|
|
// Delay caching until EOF is reached.
|
|
resp.Body = &cachingReadCloser{
|
|
R: resp.Body,
|
|
OnEOF: func(r io.Reader) {
|
|
resp := *resp
|
|
resp.Body = ioutil.NopCloser(r)
|
|
respBytes, err := httputil.DumpResponse(&resp, true)
|
|
if err == nil {
|
|
t.Cache.Set(cacheKey, respBytes)
|
|
}
|
|
},
|
|
}
|
|
default:
|
|
respBytes, err := httputil.DumpResponse(resp, true)
|
|
if err == nil {
|
|
t.Cache.Set(cacheKey, respBytes)
|
|
}
|
|
}
|
|
} else {
|
|
t.Cache.Delete(cacheKey)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// ErrNoDateHeader indicates that the HTTP headers contained no Date header.
var ErrNoDateHeader = errors.New("no Date header")

// Date parses and returns the value of the Date header.
//
// The value is first parsed as time.RFC1123, which accepts any zone
// abbreviation as before. If that fails, http.ParseTime is tried, so the
// obsolete RFC 850 and ANSI C asctime HTTP-date forms (which HTTP/1.1
// recipients must accept, RFC 7231 §7.1.1.1) are understood too.
// ErrNoDateHeader is returned when the header is absent or empty. If neither
// parser succeeds, the RFC 1123 parse error is returned.
func Date(respHeaders http.Header) (date time.Time, err error) {
	dateHeader := respHeaders.Get("date")
	if dateHeader == "" {
		return time.Time{}, ErrNoDateHeader
	}

	date, err = time.Parse(time.RFC1123, dateHeader)
	if err == nil {
		return date, nil
	}
	if alt, altErr := http.ParseTime(dateHeader); altErr == nil {
		return alt, nil
	}
	// Report the original RFC 1123 error so existing error text is unchanged.
	return date, err
}
|
|
|
|
// timer reports elapsed time. The package reads it through the clock variable
// rather than calling time.Since directly (presumably so tests can substitute
// a fake; confirm against the test files).
type timer interface {
	since(d time.Time) time.Duration
}

// realClock is the production timer, backed by the system clock.
type realClock struct{}

// since returns the duration elapsed since d, exactly as time.Since does.
func (c *realClock) since(d time.Time) time.Duration { return time.Since(d) }

// clock is the timer used for age calculations.
var clock timer = &realClock{}
|
|
|
|
// getFreshness will return one of fresh/stale/transparent based on the cache-control
|
|
// values of the request and the response
|
|
//
|
|
// fresh indicates the response can be returned
|
|
// stale indicates that the response needs validating before it is returned
|
|
// transparent indicates the response should not be used to fulfil the request
|
|
//
|
|
// Because this is only a private cache, 'public' and 'private' in cache-control aren't
|
|
// signficant. Similarly, smax-age isn't used.
|
|
func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
|
|
respCacheControl := parseCacheControl(respHeaders)
|
|
reqCacheControl := parseCacheControl(reqHeaders)
|
|
if _, ok := reqCacheControl["no-cache"]; ok {
|
|
return transparent
|
|
}
|
|
if _, ok := respCacheControl["no-cache"]; ok {
|
|
return stale
|
|
}
|
|
if _, ok := reqCacheControl["only-if-cached"]; ok {
|
|
return fresh
|
|
}
|
|
|
|
date, err := Date(respHeaders)
|
|
if err != nil {
|
|
return stale
|
|
}
|
|
currentAge := clock.since(date)
|
|
|
|
var lifetime time.Duration
|
|
var zeroDuration time.Duration
|
|
|
|
// If a response includes both an Expires header and a max-age directive,
|
|
// the max-age directive overrides the Expires header, even if the Expires header is more restrictive.
|
|
if maxAge, ok := respCacheControl["max-age"]; ok {
|
|
lifetime, err = time.ParseDuration(maxAge + "s")
|
|
if err != nil {
|
|
lifetime = zeroDuration
|
|
}
|
|
} else {
|
|
expiresHeader := respHeaders.Get("Expires")
|
|
if expiresHeader != "" {
|
|
expires, err := time.Parse(time.RFC1123, expiresHeader)
|
|
if err != nil {
|
|
lifetime = zeroDuration
|
|
} else {
|
|
lifetime = expires.Sub(date)
|
|
}
|
|
}
|
|
}
|
|
|
|
if maxAge, ok := reqCacheControl["max-age"]; ok {
|
|
// the client is willing to accept a response whose age is no greater than the specified time in seconds
|
|
lifetime, err = time.ParseDuration(maxAge + "s")
|
|
if err != nil {
|
|
lifetime = zeroDuration
|
|
}
|
|
}
|
|
if minfresh, ok := reqCacheControl["min-fresh"]; ok {
|
|
// the client wants a response that will still be fresh for at least the specified number of seconds.
|
|
minfreshDuration, err := time.ParseDuration(minfresh + "s")
|
|
if err == nil {
|
|
currentAge = time.Duration(currentAge + minfreshDuration)
|
|
}
|
|
}
|
|
|
|
if maxstale, ok := reqCacheControl["max-stale"]; ok {
|
|
// Indicates that the client is willing to accept a response that has exceeded its expiration time.
|
|
// If max-stale is assigned a value, then the client is willing to accept a response that has exceeded
|
|
// its expiration time by no more than the specified number of seconds.
|
|
// If no value is assigned to max-stale, then the client is willing to accept a stale response of any age.
|
|
//
|
|
// Responses served only because of a max-stale value are supposed to have a Warning header added to them,
|
|
// but that seems like a hassle, and is it actually useful? If so, then there needs to be a different
|
|
// return-value available here.
|
|
if maxstale == "" {
|
|
return fresh
|
|
}
|
|
maxstaleDuration, err := time.ParseDuration(maxstale + "s")
|
|
if err == nil {
|
|
currentAge = time.Duration(currentAge - maxstaleDuration)
|
|
}
|
|
}
|
|
|
|
if lifetime > currentAge {
|
|
return fresh
|
|
}
|
|
|
|
return stale
|
|
}
|
|
|
|
// Returns true if either the request or the response includes the stale-if-error
|
|
// cache control extension: https://tools.ietf.org/html/rfc5861
|
|
func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
|
|
respCacheControl := parseCacheControl(respHeaders)
|
|
reqCacheControl := parseCacheControl(reqHeaders)
|
|
|
|
var err error
|
|
lifetime := time.Duration(-1)
|
|
|
|
if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok {
|
|
if staleMaxAge != "" {
|
|
lifetime, err = time.ParseDuration(staleMaxAge + "s")
|
|
if err != nil {
|
|
return false
|
|
}
|
|
} else {
|
|
return true
|
|
}
|
|
}
|
|
if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok {
|
|
if staleMaxAge != "" {
|
|
lifetime, err = time.ParseDuration(staleMaxAge + "s")
|
|
if err != nil {
|
|
return false
|
|
}
|
|
} else {
|
|
return true
|
|
}
|
|
}
|
|
|
|
if lifetime >= 0 {
|
|
date, err := Date(respHeaders)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
currentAge := clock.since(date)
|
|
if lifetime > currentAge {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// getEndToEndHeaders returns the (canonical) names of the end-to-end headers
// in respHeaders: every header except the fixed hop-by-hop set below and any
// header nominated as hop-by-hop by a Connection header. Order is unspecified.
func getEndToEndHeaders(respHeaders http.Header) []string {
	// These headers are always hop-by-hop
	hopByHopHeaders := map[string]struct{}{
		"Connection":          {},
		"Keep-Alive":          {},
		"Proxy-Authenticate":  {},
		"Proxy-Authorization": {},
		"Te":                  {},
		"Trailers":            {},
		"Transfer-Encoding":   {},
		"Upgrade":             {},
	}

	// Any header listed in Connection (on any of its lines) is also hop-by-hop.
	// Each token must be trimmed *before* canonicalizing: CanonicalHeaderKey
	// returns keys containing a space unchanged, so " X-Foo" would never
	// match the stored key "X-Foo".
	for _, line := range respHeaders[http.CanonicalHeaderKey("Connection")] {
		for _, extra := range strings.Split(line, ",") {
			if extra = strings.TrimSpace(extra); extra != "" {
				hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{}
			}
		}
	}

	endToEndHeaders := make([]string, 0, len(respHeaders))
	for respHeader := range respHeaders {
		if _, ok := hopByHopHeaders[respHeader]; !ok {
			endToEndHeaders = append(endToEndHeaders, respHeader)
		}
	}
	return endToEndHeaders
}
|
|
|
|
func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
|
|
if _, ok := respCacheControl["no-store"]; ok {
|
|
return false
|
|
}
|
|
if _, ok := reqCacheControl["no-store"]; ok {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// newGatewayTimeoutResponse builds a bodiless "504 Gateway Timeout" response
// for req by parsing a fixed HTTP/1.1 status line. It is used when a request
// demands only-if-cached but nothing usable is cached.
func newGatewayTimeoutResponse(req *http.Request) *http.Response {
	const raw = "HTTP/1.1 504 Gateway Timeout\r\n\r\n"
	resp, err := http.ReadResponse(bufio.NewReader(strings.NewReader(raw)), req)
	if err != nil {
		// The input is a constant, well-formed response, so failure here would be
		// a programming error rather than a runtime condition.
		panic(err)
	}
	return resp
}
|
|
|
|
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct, with its own Header map and its
// own copy of every header's value slice. Edits to the clone's headers
// (Set, Add, or writing into a value slice) therefore never affect r. Other
// reference fields such as URL and Body are still shared with r.
// (This function copyright goauth2 authors: https://code.google.com/p/goauth2)
func cloneRequest(r *http.Request) *http.Request {
	// shallow copy of the struct
	r2 := new(http.Request)
	*r2 = *r
	// deep copy of the Header, including each value slice; copying only the
	// map would leave both requests aliasing the same []string backing arrays.
	r2.Header = make(http.Header, len(r.Header))
	for k, s := range r.Header {
		r2.Header[k] = append([]string(nil), s...)
	}
	return r2
}
|
|
|
|
// cacheControl maps Cache-Control directive names (lower-cased) to their
// arguments; directives without an argument map to "".
type cacheControl map[string]string

// parseCacheControl parses every Cache-Control header line in headers into a
// cacheControl map.
//
// Directive names are lower-cased because RFC 7234 §5.2 says they compare
// case-insensitively. Each directive is split on its first '=' only, and
// whitespace around the name and the argument is trimmed. A quoted-string
// argument has its enclosing double quotes removed, so max-age="60" yields
// "60". Commas inside quoted strings are not treated specially. When a
// directive repeats, the last occurrence wins.
func parseCacheControl(headers http.Header) cacheControl {
	cc := cacheControl{}
	for _, line := range headers[http.CanonicalHeaderKey("Cache-Control")] {
		for _, part := range strings.Split(line, ",") {
			part = strings.TrimSpace(part)
			if part == "" {
				continue
			}
			key, val := part, ""
			if i := strings.IndexByte(part, '='); i >= 0 {
				key = strings.TrimSpace(part[:i])
				val = strings.TrimSpace(part[i+1:])
				if len(val) >= 2 && val[0] == '"' && val[len(val)-1] == '"' {
					val = val[1 : len(val)-1]
				}
			}
			cc[strings.ToLower(key)] = val
		}
	}
	return cc
}
|
|
|
|
// headerAllCommaSepValues collects every comma-separated element, with
// surrounding whitespace trimmed, from all lines of header name in headers.
// Multiple lines of a list-valued header are equivalent to one comma-joined
// line (RFC 7230 §3.2.2), so they are concatenated in order. The result is nil
// when the header is absent.
func headerAllCommaSepValues(headers http.Header, name string) []string {
	var out []string
	key := http.CanonicalHeaderKey(name)
	for _, line := range headers[key] {
		for _, field := range strings.Split(line, ",") {
			out = append(out, strings.TrimSpace(field))
		}
	}
	return out
}
|
|
|
|
// cachingReadCloser wraps the ReadCloser R and mirrors every byte read from it
// into an internal buffer. Whenever a Read on R reports io.EOF, OnEOF receives
// a reader over everything buffered so far.
type cachingReadCloser struct {
	// R is the wrapped ReadCloser.
	R io.ReadCloser
	// OnEOF receives a reader over the buffered content each time R reports io.EOF.
	OnEOF func(io.Reader)

	buf bytes.Buffer // copy of every byte read from R so far
}

// Read forwards to R, appends the n bytes R produced to the internal buffer,
// and, when R returns io.EOF, passes OnEOF a reader over the full copy. It
// returns exactly what R.Read returned.
func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
	n, err = r.R.Read(p)
	r.buf.Write(p[:n])
	if err != io.EOF {
		return n, err
	}
	snapshot := bytes.NewReader(r.buf.Bytes())
	r.OnEOF(snapshot)
	return n, err
}

// Close closes the wrapped ReadCloser; it does not invoke OnEOF.
func (r *cachingReadCloser) Close() error { return r.R.Close() }
|
|
|
|
// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation
|
|
func NewMemoryCacheTransport() *Transport {
|
|
c := NewMemoryCache()
|
|
t := NewTransport(c)
|
|
return t
|
|
}
|