diff --git a/smart_open/s3.py b/smart_open/s3.py index 60f3ef0e..7b6dfe91 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -28,6 +28,11 @@ from smart_open import constants +from typing import ( + Callable, + List, +) + logger = logging.getLogger(__name__) DEFAULT_MIN_PART_SIZE = 50 * 1024**2 @@ -47,13 +52,52 @@ 's3://my_key:my_secret@my_server:my_port@my_bucket/my_key', ) -_UPLOAD_ATTEMPTS = 6 -_SLEEP_SECONDS = 10 - # Returned by AWS when we try to seek beyond EOF. _OUT_OF_RANGE = 'InvalidRange' +class Retry: + def __init__(self): + self.attempts: int = 6 + self.sleep_seconds: int = 10 + self.exceptions: List[Exception] = [botocore.exceptions.EndpointConnectionError] + self.client_error_codes: List[str] = ['NoSuchUpload'] + + def _do(self, fn: Callable): + for attempt in range(self.attempts): + try: + return fn() + except tuple(self.exceptions) as err: + logger.critical( + 'Caught non-fatal %s, retrying %d more times', + err, + self.attempts - attempt - 1, + ) + logger.exception(err) + time.sleep(self.sleep_seconds) + except botocore.exceptions.ClientError as err: + error_code = err.response['Error'].get('Code') + if error_code not in self.client_error_codes: + raise + logger.critical( + 'Caught non-fatal ClientError (%s), retrying %d more times', + error_code, + self.attempts - attempt - 1, + ) + logger.exception(err) + time.sleep(self.sleep_seconds) + else: + logger.critical('encountered too many non-fatal errors, giving up') + raise IOError('%s failed after %d attempts', fn.func, self.attempts) + + +# +# The retry mechanism for this submodule. Client code may modify it, e.g. by +# updating RETRY.sleep_seconds and friends. +# +RETRY = Retry() + + class _ClientWrapper: """Wraps a client to inject the appropriate keyword args into each method call. @@ -803,7 +847,7 @@ def __init__( Bucket=bucket, Key=key, ) - self._upload_id = _retry_if_failed(partial)['UploadId'] + self._upload_id = RETRY._do(partial)['UploadId'] except botocore.client.ClientError as error: raise ValueError( 'the bucket %r does not exist, or is forbidden for access (%r)' % ( @@ -843,7 +887,7 @@ def close(self): UploadId=self._upload_id, MultipartUpload={'Parts': self._parts}, ) - _retry_if_failed(partial) + RETRY._do(partial) logger.debug('%s: completed multipart upload', self) elif self._upload_id: # @@ -954,7 +998,7 @@ def _upload_next_part(self): # of a temporary connection problem, so this part needs to be # especially robust. # - upload = _retry_if_failed( + upload = RETRY._do( functools.partial( self._client.upload_part, Bucket=self._bucket, @@ -1119,28 +1163,6 @@ def __repr__(self): return "smart_open.s3.SinglepartWriter(bucket=%r, key=%r)" % (self._bucket, self._key) -def _retry_if_failed( - partial, - attempts=_UPLOAD_ATTEMPTS, - sleep_seconds=_SLEEP_SECONDS, - exceptions=None): - if exceptions is None: - exceptions = (botocore.exceptions.EndpointConnectionError, ) - for attempt in range(attempts): - try: - return partial() - except exceptions: - logger.critical( - 'Unable to connect to the endpoint. Check your network connection. ' - 'Sleeping and retrying %d more times ' - 'before giving up.' % (attempts - attempt - 1) - ) - time.sleep(sleep_seconds) - else: - logger.critical('Unable to connect to the endpoint. Giving up.') - raise IOError('Unable to connect to the endpoint after %d attempts' % attempts) - - def _accept_all(key): return True diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index b2ab87be..d6e8c76a 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -21,6 +21,7 @@ import boto3 import botocore.client import botocore.endpoint +import botocore.exceptions import pytest # See https://github.com/piskvorky/smart_open/issues/800 @@ -952,19 +953,32 @@ def populate_bucket(num_keys=10): class RetryIfFailedTest(unittest.TestCase): + def setUp(self): + self.retry = smart_open.s3.Retry() + self.retry.attempts = 3 + self.retry.sleep_seconds = 0 + def test_success(self): partial = mock.Mock(return_value=1) - result = smart_open.s3._retry_if_failed(partial, attempts=3, sleep_seconds=0) + result = self.retry._do(partial) self.assertEqual(result, 1) self.assertEqual(partial.call_count, 1) - def test_failure(self): + def test_failure_exception(self): partial = mock.Mock(side_effect=ValueError) - exceptions = (ValueError, ) - + self.retry.exceptions = {ValueError: 'Let us retry ValueError'} with self.assertRaises(IOError): - smart_open.s3._retry_if_failed(partial, attempts=3, sleep_seconds=0, exceptions=exceptions) + self.retry._do(partial) + self.assertEqual(partial.call_count, 3) + def test_failure_client_error(self): + partial = mock.Mock( + side_effect=botocore.exceptions.ClientError( + {'Error': {'Code': 'NoSuchUpload'}}, 'NoSuchUpload' + ) + ) + with self.assertRaises(IOError): + self.retry._do(partial) self.assertEqual(partial.call_count, 3)