diff --git a/storage/google/cloud/storage/blob.py b/storage/google/cloud/storage/blob.py index 47e2571f770bc..21d92acd955ad 100644 --- a/storage/google/cloud/storage/blob.py +++ b/storage/google/cloud/storage/blob.py @@ -34,7 +34,11 @@ import time import warnings +from six.moves.urllib.parse import parse_qsl from six.moves.urllib.parse import quote +from six.moves.urllib.parse import urlencode +from six.moves.urllib.parse import urlsplit +from six.moves.urllib.parse import urlunsplit from google import resumable_media from google.resumable_media.requests import ChunkedDownload @@ -399,15 +403,19 @@ def _get_download_url(self): :rtype: str :returns: The download URL for the current blob. """ + name_value_pairs = [] if self.media_link is None: - download_url = _DOWNLOAD_URL_TEMPLATE.format(path=self.path) + base_url = _DOWNLOAD_URL_TEMPLATE.format(path=self.path) if self.generation is not None: - download_url += u'&generation={:d}'.format(self.generation) - if self.user_project is not None: - download_url += u'&userProject={}'.format(self.user_project) - return download_url + name_value_pairs.append( + ('generation', '{:d}'.format(self.generation))) else: - return self.media_link + base_url = self.media_link + + if self.user_project is not None: + name_value_pairs.append(('userProject', self.user_project)) + + return _add_query_parameters(base_url, name_value_pairs) def _do_download(self, transport, file_obj, download_url, headers): """Perform a download without any error handling. 
@@ -653,12 +661,14 @@ def _do_multipart_upload(self, client, stream, content_type, info = self._get_upload_arguments(content_type) headers, object_metadata, content_type = info - upload_url = _MULTIPART_URL_TEMPLATE.format( + base_url = _MULTIPART_URL_TEMPLATE.format( bucket_path=self.bucket.path) + name_value_pairs = [] if self.user_project is not None: - upload_url += '&userProject={}'.format(self.user_project) + name_value_pairs.append(('userProject', self.user_project)) + upload_url = _add_query_parameters(base_url, name_value_pairs) upload = MultipartUpload(upload_url, headers=headers) if num_retries is not None: @@ -729,12 +739,14 @@ def _initiate_resumable_upload(self, client, stream, content_type, if extra_headers is not None: headers.update(extra_headers) - upload_url = _RESUMABLE_URL_TEMPLATE.format( + base_url = _RESUMABLE_URL_TEMPLATE.format( bucket_path=self.bucket.path) + name_value_pairs = [] if self.user_project is not None: - upload_url += '&userProject={}'.format(self.user_project) + name_value_pairs.append(('userProject', self.user_project)) + upload_url = _add_query_parameters(base_url, name_value_pairs) upload = ResumableUpload(upload_url, chunk_size, headers=headers) if num_retries is not None: @@ -1670,3 +1682,24 @@ def _raise_from_invalid_response(error): to the failed status code """ raise exceptions.from_http_response(error.response) + + +def _add_query_parameters(base_url, name_value_pairs): + """Add one or more query parameters to a base URL. + + :type base_url: string + :param base_url: Base URL (may already contain query parameters) + + :type name_value_pairs: list of (string, string) tuples. + :param name_value_pairs: Names and values of the query parameters to add + + :rtype: string + :returns: URL with additional query strings appended. 
+ """ + if len(name_value_pairs) == 0: + return base_url + + scheme, netloc, path, query, frag = urlsplit(base_url) + query = parse_qsl(query) + query.extend(name_value_pairs) + return urlunsplit((scheme, netloc, path, urlencode(query), frag)) diff --git a/storage/tests/unit/test_blob.py b/storage/tests/unit/test_blob.py index 627fe364c78eb..9ce326d818c41 100644 --- a/storage/tests/unit/test_blob.py +++ b/storage/tests/unit/test_blob.py @@ -366,7 +366,7 @@ def test__get_transport(self): def test__get_download_url_with_media_link(self): blob_name = 'something.txt' - bucket = mock.Mock(spec=[]) + bucket = _Bucket(name='IRRELEVANT') blob = self._make_one(blob_name, bucket=bucket) media_link = 'http://test.invalid' # Set the media link on the blob @@ -375,6 +375,19 @@ def test__get_download_url_with_media_link(self): download_url = blob._get_download_url() self.assertEqual(download_url, media_link) + def test__get_download_url_with_media_link_w_user_project(self): + blob_name = 'something.txt' + user_project = 'user-project-123' + bucket = _Bucket(name='IRRELEVANT', user_project=user_project) + blob = self._make_one(blob_name, bucket=bucket) + media_link = 'http://test.invalid' + # Set the media link on the blob + blob._properties['mediaLink'] = media_link + + download_url = blob._get_download_url() + self.assertEqual( + download_url, '{}?userProject={}'.format(media_link, user_project)) + def test__get_download_url_on_the_fly(self): blob_name = 'bzzz-fly.txt' bucket = _Bucket(name='buhkit') @@ -2417,6 +2430,37 @@ def test_default(self): self.assertEqual(exc_info.exception.errors, []) +class Test__add_query_parameters(unittest.TestCase): + + @staticmethod + def _call_fut(*args, **kwargs): + from google.cloud.storage.blob import _add_query_parameters + + return _add_query_parameters(*args, **kwargs) + + def test_w_empty_list(self): + BASE_URL = 'https://test.example.com/base' + self.assertEqual(self._call_fut(BASE_URL, []), BASE_URL) + + def test_wo_existing_qs(self): + 
BASE_URL = 'https://test.example.com/base' + NV_LIST = [('one', 'One'), ('two', 'Two')] + expected = '&'.join([ + '{}={}'.format(name, value) for name, value in NV_LIST]) + self.assertEqual( + self._call_fut(BASE_URL, NV_LIST), + '{}?{}'.format(BASE_URL, expected)) + + def test_w_existing_qs(self): + BASE_URL = 'https://test.example.com/base?one=Three' + NV_LIST = [('one', 'One'), ('two', 'Two')] + expected = '&'.join([ + '{}={}'.format(name, value) for name, value in NV_LIST]) + self.assertEqual( + self._call_fut(BASE_URL, NV_LIST), + '{}&{}'.format(BASE_URL, expected)) + + class _Connection(object): API_BASE_URL = 'http://example.com'