Merge branch 'v1.5' into master

This commit is contained in:
Fernandez Ludovic 2018-03-09 12:02:29 +01:00
commit 0a41cd43a5
4 changed files with 70 additions and 21 deletions

View file

@ -372,7 +372,7 @@ func getBoolValue(i ecsInstance, labelName string, defaultValue bool) bool {
rawValue, ok := i.containerDefinition.DockerLabels[labelName] rawValue, ok := i.containerDefinition.DockerLabels[labelName]
if ok { if ok {
if rawValue != nil { if rawValue != nil {
v, err := strconv.ParseBool(*rawValue) v, err := strconv.ParseBool(aws.StringValue(rawValue))
if err == nil { if err == nil {
return v return v
} }

View file

@ -280,12 +280,12 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
byTaskDefinition := make(map[string]int) byTaskDefinition := make(map[string]int)
for _, task := range tasks { for _, task := range tasks {
if _, found := byContainerInstance[*task.ContainerInstanceArn]; !found { if _, found := byContainerInstance[aws.StringValue(task.ContainerInstanceArn)]; !found {
byContainerInstance[*task.ContainerInstanceArn] = len(containerInstanceArns) byContainerInstance[aws.StringValue(task.ContainerInstanceArn)] = len(containerInstanceArns)
containerInstanceArns = append(containerInstanceArns, task.ContainerInstanceArn) containerInstanceArns = append(containerInstanceArns, task.ContainerInstanceArn)
} }
if _, found := byTaskDefinition[*task.TaskDefinitionArn]; !found { if _, found := byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)]; !found {
byTaskDefinition[*task.TaskDefinitionArn] = len(taskDefinitionArns) byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)] = len(taskDefinitionArns)
taskDefinitionArns = append(taskDefinitionArns, task.TaskDefinitionArn) taskDefinitionArns = append(taskDefinitionArns, task.TaskDefinitionArn)
} }
} }
@ -302,23 +302,23 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
for _, task := range tasks { for _, task := range tasks {
machineIdx := byContainerInstance[*task.ContainerInstanceArn] machineIdx := byContainerInstance[aws.StringValue(task.ContainerInstanceArn)]
taskDefIdx := byTaskDefinition[*task.TaskDefinitionArn] taskDefIdx := byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)]
for _, container := range task.Containers { for _, container := range task.Containers {
taskDefinition := taskDefinitions[taskDefIdx] taskDefinition := taskDefinitions[taskDefIdx]
var containerDefinition *ecs.ContainerDefinition var containerDefinition *ecs.ContainerDefinition
for _, def := range taskDefinition.ContainerDefinitions { for _, def := range taskDefinition.ContainerDefinitions {
if *container.Name == *def.Name { if aws.StringValue(container.Name) == aws.StringValue(def.Name) {
containerDefinition = def containerDefinition = def
break break
} }
} }
instances = append(instances, ecsInstance{ instances = append(instances, ecsInstance{
fmt.Sprintf("%s-%s", strings.Replace(*task.Group, ":", "-", 1), *container.Name), fmt.Sprintf("%s-%s", strings.Replace(aws.StringValue(task.Group), ":", "-", 1), *container.Name),
(*task.TaskArn)[len(*task.TaskArn)-12:], (aws.StringValue(task.TaskArn))[len(aws.StringValue(task.TaskArn))-12:],
task, task,
taskDefinition, taskDefinition,
container, container,
@ -338,7 +338,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
instanceIds := make([]*string, len(containerArns)) instanceIds := make([]*string, len(containerArns))
instances := make([]*ec2.Instance, len(containerArns)) instances := make([]*ec2.Instance, len(containerArns))
for i, arn := range containerArns { for i, arn := range containerArns {
order[*arn] = i order[aws.StringValue(arn)] = i
} }
req, _ := client.ecs.DescribeContainerInstancesRequest(&ecs.DescribeContainerInstancesInput{ req, _ := client.ecs.DescribeContainerInstancesRequest(&ecs.DescribeContainerInstancesInput{
@ -353,7 +353,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
containerResp := req.Data.(*ecs.DescribeContainerInstancesOutput) containerResp := req.Data.(*ecs.DescribeContainerInstancesOutput)
for i, container := range containerResp.ContainerInstances { for i, container := range containerResp.ContainerInstances {
order[*container.Ec2InstanceId] = order[*container.ContainerInstanceArn] order[aws.StringValue(container.Ec2InstanceId)] = order[aws.StringValue(container.ContainerInstanceArn)]
instanceIds[i] = container.Ec2InstanceId instanceIds[i] = container.Ec2InstanceId
} }
} }
@ -371,7 +371,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
for _, r := range instancesResp.Reservations { for _, r := range instancesResp.Reservations {
for _, i := range r.Instances { for _, i := range r.Instances {
if i.InstanceId != nil { if i.InstanceId != nil {
instances[order[*i.InstanceId]] = i instances[order[aws.StringValue(i.InstanceId)]] = i
} }
} }
} }
@ -408,8 +408,8 @@ func (p *Provider) filterInstance(i ecsInstance) bool {
return false return false
} }
if *i.machine.State.Name != ec2.InstanceStateNameRunning { if aws.StringValue(i.machine.State.Name) != ec2.InstanceStateNameRunning {
log.Debugf("Filtering ecs instance in an incorrect state %s (%s) (state = %s)", i.Name, i.ID, *i.machine.State.Name) log.Debugf("Filtering ecs instance in an incorrect state %s (%s) (state = %s)", i.Name, i.ID, aws.StringValue(i.machine.State.Name))
return false return false
} }

View file

@ -17,10 +17,10 @@ type IP struct {
// NewIP builds a new IP given a list of CIDR-Strings to whitelist // NewIP builds a new IP given a list of CIDR-Strings to whitelist
func NewIP(whitelistStrings []string, insecure bool) (*IP, error) { func NewIP(whitelistStrings []string, insecure bool) (*IP, error) {
if len(whitelistStrings) == 0 && !insecure { if len(whitelistStrings) == 0 && !insecure {
return nil, errors.New("no whiteListsNet provided") return nil, errors.New("no white list provided")
} }
ip := IP{} ip := IP{insecure: insecure}
if !insecure { if !insecure {
for _, whitelistString := range whitelistStrings { for _, whitelistString := range whitelistStrings {

View file

@ -19,12 +19,12 @@ func TestNew(t *testing.T) {
desc: "nil whitelist", desc: "nil whitelist",
whitelistStrings: nil, whitelistStrings: nil,
expectedWhitelists: nil, expectedWhitelists: nil,
errMessage: "no whiteListsNet provided", errMessage: "no white list provided",
}, { }, {
desc: "empty whitelist", desc: "empty whitelist",
whitelistStrings: []string{}, whitelistStrings: []string{},
expectedWhitelists: nil, expectedWhitelists: nil,
errMessage: "no whiteListsNet provided", errMessage: "no white list provided",
}, { }, {
desc: "whitelist containing empty string", desc: "whitelist containing empty string",
whitelistStrings: []string{ whitelistStrings: []string{
@ -90,7 +90,7 @@ func TestNew(t *testing.T) {
} }
} }
func TestIsAllowed(t *testing.T) { func TestContainsIsAllowed(t *testing.T) {
cases := []struct { cases := []struct {
desc string desc string
whitelistStrings []string whitelistStrings []string
@ -275,6 +275,7 @@ func TestIsAllowed(t *testing.T) {
test := test test := test
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
t.Parallel() t.Parallel()
whiteLister, err := NewIP(test.whitelistStrings, false) whiteLister, err := NewIP(test.whitelistStrings, false)
require.NoError(t, err) require.NoError(t, err)
@ -297,7 +298,55 @@ func TestIsAllowed(t *testing.T) {
} }
} }
func TestBrokenIPs(t *testing.T) { func TestContainsInsecure(t *testing.T) {
mustNewIP := func(whitelistStrings []string, insecure bool) *IP {
ip, err := NewIP(whitelistStrings, insecure)
if err != nil {
t.Fatal(err)
}
return ip
}
testCases := []struct {
desc string
whiteLister *IP
ip string
expected bool
}{
{
desc: "valid ip and insecure",
whiteLister: mustNewIP([]string{"1.2.3.4/24"}, true),
ip: "1.2.3.1",
expected: true,
},
{
desc: "invalid ip and insecure",
whiteLister: mustNewIP([]string{"1.2.3.4/24"}, true),
ip: "10.2.3.1",
expected: true,
},
{
desc: "invalid ip and secure",
whiteLister: mustNewIP([]string{"1.2.3.4/24"}, false),
ip: "10.2.3.1",
expected: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ok, _, err := test.whiteLister.Contains(test.ip)
require.NoError(t, err)
assert.Equal(t, test.expected, ok)
})
}
}
func TestContainsBrokenIPs(t *testing.T) {
brokenIPs := []string{ brokenIPs := []string{
"foo", "foo",
"10.0.0.350", "10.0.0.350",