acme: Fix race condition in LocalStore during saving.

This commit is contained in:
Anton Popovichenko 2020-09-30 13:04:04 +03:00 committed by GitHub
parent ddc663eac0
commit ab13019bde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 104 additions and 2 deletions

View file

@ -34,7 +34,10 @@ func (s *LocalStore) save(resolverName string, storedData *StoredData) {
defer s.lock.Unlock() defer s.lock.Unlock()
s.storedData[resolverName] = storedData s.storedData[resolverName] = storedData
s.saveDataChan <- s.storedData
// we cannot pass s.storedData directly, map is reference type and as result
// we can face with race condition, so we need to work with objects copy
s.saveDataChan <- s.unSafeCopyOfStoredData()
} }
func (s *LocalStore) get(resolverName string) (*StoredData, error) { func (s *LocalStore) get(resolverName string) (*StoredData, error) {
@ -81,7 +84,10 @@ func (s *LocalStore) get(resolverName string) (*StoredData, error) {
} }
if len(certificates) < len(storedData.Certificates) { if len(certificates) < len(storedData.Certificates) {
storedData.Certificates = certificates storedData.Certificates = certificates
s.saveDataChan <- s.storedData
// we cannot pass s.storedData directly, map is reference type and as result
// we can face with race condition, so we need to work with objects copy
s.saveDataChan <- s.unSafeCopyOfStoredData()
} }
} }
} }
@ -111,6 +117,15 @@ func (s *LocalStore) listenSaveAction() {
}) })
} }
// unSafeCopyOfStoredData creates maps copy of storedData. Is not thread safe, you should use `s.lock`.
func (s *LocalStore) unSafeCopyOfStoredData() map[string]*StoredData {
result := map[string]*StoredData{}
for k, v := range s.storedData {
result[k] = v
}
return result
}
// GetAccount returns ACME Account. // GetAccount returns ACME Account.
func (s *LocalStore) GetAccount(resolverName string) (*Account, error) { func (s *LocalStore) GetAccount(resolverName string) (*Account, error) {
storedData, err := s.get(resolverName) storedData, err := s.get(resolverName)

View file

@ -0,0 +1,87 @@
package acme
import (
"fmt"
"io/ioutil"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLocalStore_GetAccount(t *testing.T) {
acmeFile := filepath.Join(t.TempDir(), "acme.json")
email := "some42@email.com"
filePayload := fmt.Sprintf(`{
"test": {
"Account": {
"Email": "%s"
}
}
}`, email)
err := ioutil.WriteFile(acmeFile, []byte(filePayload), 0o600)
require.NoError(t, err)
testCases := []struct {
desc string
filename string
expected *Account
}{
{
desc: "empty file",
filename: filepath.Join(t.TempDir(), "acme-empty.json"),
expected: nil,
},
{
desc: "file with data",
filename: acmeFile,
expected: &Account{Email: "some42@email.com"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
s := NewLocalStore(test.filename)
account, err := s.GetAccount("test")
require.NoError(t, err)
assert.Equal(t, test.expected, account)
})
}
}
func TestLocalStore_SaveAccount(t *testing.T) {
acmeFile := filepath.Join(t.TempDir(), "acme.json")
s := NewLocalStore(acmeFile)
email := "some@email.com"
err := s.SaveAccount("test", &Account{Email: email})
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
file, err := ioutil.ReadFile(acmeFile)
require.NoError(t, err)
expected := `{
"test": {
"Account": {
"Email": "some@email.com",
"Registration": null,
"PrivateKey": null,
"KeyType": ""
},
"Certificates": null
}
}`
assert.Equal(t, expected, string(file))
}