99 lines
2.2 KiB
Go
99 lines
2.2 KiB
Go
package ip
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
)
|
|
|
|
// Checker decides whether addresses belong to a configured set of trusted
// IPs. Trusted entries are stored in two lists: exact addresses and CIDR
// networks. Build one with NewChecker.
type Checker struct {
	// authorizedIPs holds trusted entries that were given as single IP
	// addresses.
	authorizedIPs []*net.IP
	// authorizedIPsNet holds trusted entries that were given as CIDR ranges.
	authorizedIPsNet []*net.IPNet
}
|
|
|
|
// NewChecker builds a new Checker given a list of CIDR-Strings to trusted IPs
|
|
func NewChecker(trustedIPs []string) (*Checker, error) {
|
|
if len(trustedIPs) == 0 {
|
|
return nil, errors.New("no trusted IPs provided")
|
|
}
|
|
|
|
checker := &Checker{}
|
|
|
|
for _, ipMask := range trustedIPs {
|
|
if ipAddr := net.ParseIP(ipMask); ipAddr != nil {
|
|
checker.authorizedIPs = append(checker.authorizedIPs, &ipAddr)
|
|
} else {
|
|
_, ipAddr, err := net.ParseCIDR(ipMask)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing CIDR trusted IPs %s: %v", ipAddr, err)
|
|
}
|
|
checker.authorizedIPsNet = append(checker.authorizedIPsNet, ipAddr)
|
|
}
|
|
}
|
|
|
|
return checker, nil
|
|
}
|
|
|
|
// IsAuthorized checks if provided request is authorized by the trusted IPs
|
|
func (ip *Checker) IsAuthorized(addr string) error {
|
|
var invalidMatches []string
|
|
|
|
host, _, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
host = addr
|
|
}
|
|
|
|
ok, err := ip.Contains(host)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !ok {
|
|
invalidMatches = append(invalidMatches, addr)
|
|
return fmt.Errorf("%q matched none of the trusted IPs", strings.Join(invalidMatches, ", "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Contains checks if provided address is in the trusted IPs
|
|
func (ip *Checker) Contains(addr string) (bool, error) {
|
|
if len(addr) <= 0 {
|
|
return false, errors.New("empty IP address")
|
|
}
|
|
|
|
ipAddr, err := parseIP(addr)
|
|
if err != nil {
|
|
return false, fmt.Errorf("unable to parse address: %s: %s", addr, err)
|
|
}
|
|
|
|
return ip.ContainsIP(ipAddr), nil
|
|
}
|
|
|
|
// ContainsIP checks if provided address is in the trusted IPs
|
|
func (ip *Checker) ContainsIP(addr net.IP) bool {
|
|
for _, authorizedIP := range ip.authorizedIPs {
|
|
if authorizedIP.Equal(addr) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
for _, authorizedNet := range ip.authorizedIPsNet {
|
|
if authorizedNet.Contains(addr) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// parseIP converts addr into a net.IP. It returns an error when net.ParseIP
// rejects addr (for example when it carries a port or brackets).
func parseIP(addr string) (net.IP, error) {
	if parsed := net.ParseIP(addr); parsed != nil {
		return parsed, nil
	}
	return nil, fmt.Errorf("can't parse IP from address %s", addr)
}
|