diff --git a/.changes/next-release/bugfix-Credentials-23327.json b/.changes/next-release/bugfix-Credentials-23327.json new file mode 100644 index 0000000000..1353584a92 --- /dev/null +++ b/.changes/next-release/bugfix-Credentials-23327.json @@ -0,0 +1,5 @@ +{ + "category": "Credentials", + "type": "bugfix", + "description": "Fix a race condition related to assuming a role for the first time (`#1405 `__)" +} diff --git a/botocore/credentials.py b/botocore/credentials.py index 370c6d80aa..6d2a0d5949 100644 --- a/botocore/credentials.py +++ b/botocore/credentials.py @@ -444,6 +444,8 @@ def _protected_refresh(self, is_mandatory): # set of temporary credentials we have. return self._set_from_data(metadata) + self._frozen_credentials = ReadOnlyCredentials( + self._access_key, self._secret_key, self._token) if self._is_expired(): # We successfully refreshed credentials but for whatever # reason, our refreshing function returned credentials @@ -454,8 +456,6 @@ def _protected_refresh(self, is_mandatory): "refreshed credentials are still expired.") logger.warning(msg) raise RuntimeError(msg) - self._frozen_credentials = ReadOnlyCredentials( - self._access_key, self._secret_key, self._token) @staticmethod def _expiry_datetime(time_str): @@ -525,7 +525,7 @@ def __init__(self, refresh_using, method, time_fetcher=_local_now): self._frozen_credentials = None def refresh_needed(self, refresh_in=None): - if any(part is None for part in [self._access_key, self._secret_key]): + if self._frozen_credentials is None: return True return super(DeferredRefreshableCredentials, self).refresh_needed( refresh_in diff --git a/tests/functional/test_credentials.py b/tests/functional/test_credentials.py index ef9908d04a..8aea11ae51 100644 --- a/tests/functional/test_credentials.py +++ b/tests/functional/test_credentials.py @@ -31,6 +31,7 @@ from botocore.credentials import Credentials, ReadOnlyCredentials from botocore.credentials import AssumeRoleProvider from botocore.credentials import CanonicalNameCredentialSourcer +from botocore.credentials import DeferredRefreshableCredentials from botocore.session import Session from botocore.exceptions import InvalidConfigError, InfiniteLoopConfigError from botocore.stub import Stubber @@ -39,14 +40,7 @@ class TestCredentialRefreshRaces(unittest.TestCase): def assert_consistent_credentials_seen(self, creds, func): collected = [] - threads = [] - for _ in range(20): - threads.append(threading.Thread(target=func, args=(collected,))) - start = time.time() - for thread in threads: - thread.start() - for thread in threads: - thread.join() + self._run_threads(20, func, collected) for creds in collected: # During testing, the refresher uses it's current # refresh count as the values for the access, secret, and @@ -71,6 +65,21 @@ def assert_consistent_credentials_seen(self, creds, func): # first refresh ('1'). self.assertTrue(creds[0] == creds[1] == creds[2], creds) + def assert_non_none_retrieved_credentials(self, func): + collected = [] + self._run_threads(50, func, collected) + for cred in collected: + self.assertIsNotNone(cred) + + def _run_threads(self, num_threads, func, collected): + threads = [] + for _ in range(num_threads): + threads.append(threading.Thread(target=func, args=(collected,))) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + def test_has_no_race_conditions(self): creds = IntegerRefresher( creds_last_for=2, @@ -110,6 +119,26 @@ def _run_in_thread(collected): frozen.token)) self.assert_consistent_credentials_seen(creds, _run_in_thread) + def test_no_race_for_initial_refresh_of_deferred_refreshable(self): + def get_credentials(): + expiry_time = ( + datetime.now(tzlocal()) + timedelta(hours=24)).isoformat() + return { + 'access_key': 'my-access-key', + 'secret_key': 'my-secret-key', + 'token': 'my-token', + 'expiry_time': expiry_time + } + + deferred_creds = DeferredRefreshableCredentials( + get_credentials, 'fixed') + + def _run_in_thread(collected): + frozen = deferred_creds.get_frozen_credentials() + collected.append(frozen) + + self.assert_non_none_retrieved_credentials(_run_in_thread) + class TestAssumeRole(BaseEnvVar): def setUp(self):