Fix memory leak in Boto multipart uploads
jnm authored and jschneier committed Aug 12, 2018
1 parent 7064f73 commit cd1d205
Showing 4 changed files with 133 additions and 164 deletions.
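
The leak: both backends remembered the file position of the last uploaded part (_last_part_pos) and rewound only to that point, so the temporary buffer file retained every part already sent to S3 and grew to the size of the whole upload. The change below instead rewinds to the start of the buffer, uploads its contents as one part, then truncates it, so the buffer never holds more than one part. A minimal sketch of the new strategy, not part of the change itself: io.BytesIO stands in for the file's spooled buffer, and upload_part is a stub for the real boto part-upload call.

    import io

    BUFFER_SIZE = 5 * 1024 * 1024  # S3 rejects multipart parts under 5 MB (except the last)

    buf = io.BytesIO()
    write_counter = 0

    def upload_part(fp, part_number):
        # Stub standing in for multipart.upload_part_from_file(fp, part_number).
        print('part', part_number, ':', len(fp.read()), 'bytes')

    def flush_write_buffer():
        global write_counter
        if buf.tell():  # anything buffered since the last flush?
            write_counter += 1
            buf.seek(0)
            upload_part(buf, write_counter)
            # The fix: rewind and discard the uploaded bytes instead of
            # keeping them behind a _last_part_pos offset.
            buf.seek(0)
            buf.truncate()

    def write(content):
        buf.write(content)
        if BUFFER_SIZE <= buf.tell():
            flush_write_buffer()

With this shape, writing any amount of data through write() keeps at most one part's worth of bytes buffered at a time.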
21 changes: 8 additions & 13 deletions storages/backends/s3boto.py
@@ -14,7 +14,9 @@
 )
 from django.utils.six import BytesIO
 
-from storages.utils import check_location, clean_name, lookup_env, safe_join, setting
+from storages.utils import (
+    check_location, clean_name, lookup_env, safe_join, setting,
+)
 
 try:
     from boto import __version__ as boto_version
@@ -74,8 +76,6 @@ def __init__(self, name, mode, storage, buffer_size=None):
         if buffer_size is not None:
             self.buffer_size = buffer_size
         self._write_counter = 0
-        # file position of the latest part file uploaded
-        self._last_part_pos = 0
 
     @property
     def size(self):
@@ -125,14 +125,10 @@ def write(self, content, *args, **kwargs):
                 reduced_redundancy=self._storage.reduced_redundancy,
                 encrypt_key=self._storage.encryption,
             )
-        if self.buffer_size <= self._file_part_size:
+        if self.buffer_size <= self._buffer_file_size:
             self._flush_write_buffer()
         return super(S3BotoStorageFile, self).write(force_bytes(content), *args, **kwargs)
 
-    @property
-    def _file_part_size(self):
-        return self._buffer_file_size - self._last_part_pos
-
     @property
     def _buffer_file_size(self):
         pos = self.file.tell()
@@ -142,15 +138,14 @@ def _buffer_file_size(self):
         return length
 
     def _flush_write_buffer(self):
-        if self._file_part_size:
+        if self._buffer_file_size:
             self._write_counter += 1
-            pos = self.file.tell()
-            self.file.seek(self._last_part_pos)
+            self.file.seek(0)
             headers = self._storage.headers.copy()
             self._multipart.upload_part_from_file(
                 self.file, self._write_counter, headers=headers)
-            self.file.seek(pos)
-            self._last_part_pos = self._buffer_file_size
+            self.file.seek(0)
+            self.file.truncate()
 
     def close(self):
         if self._is_dirty:
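
For reference, a sketch of the boto (v2) multipart calls this backend drives during the flush cycle; the bucket name, key name, and payload file are placeholders, and credentials are assumed to come from the usual boto environment/config lookup.

    from boto.s3.connection import S3Connection

    conn = S3Connection()  # credentials resolved from env vars or boto config
    bucket = conn.get_bucket('example-bucket')  # placeholder bucket
    mp = bucket.initiate_multipart_upload('example-key')  # placeholder key
    with open('payload.bin', 'rb') as fp:
        mp.upload_part_from_file(fp, 1)  # part numbers start at 1
    mp.complete_upload()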
15 changes: 4 additions & 11 deletions storages/backends/s3boto3.py
@@ -79,8 +79,6 @@ def __init__(self, name, mode, storage, buffer_size=None):
         if buffer_size is not None:
             self.buffer_size = buffer_size
         self._write_counter = 0
-        # file position of the latest part file
-        self._last_part_pos = 0
 
     @property
     def size(self):
@@ -126,14 +124,10 @@ def write(self, content):
             if self._storage.encryption:
                 parameters['ServerSideEncryption'] = 'AES256'
             self._multipart = self.obj.initiate_multipart_upload(**parameters)
-        if self.buffer_size <= self._file_part_size:
+        if self.buffer_size <= self._buffer_file_size:
             self._flush_write_buffer()
         return super(S3Boto3StorageFile, self).write(force_bytes(content))
 
-    @property
-    def _file_part_size(self):
-        return self._buffer_file_size - self._last_part_pos
-
     @property
     def _buffer_file_size(self):
         pos = self.file.tell()
@@ -148,12 +142,11 @@ def _flush_write_buffer(self):
         """
         if self._buffer_file_size:
             self._write_counter += 1
-            pos = self.file.tell()
-            self.file.seek(self._last_part_pos)
+            self.file.seek(0)
             part = self._multipart.Part(self._write_counter)
             part.upload(Body=self.file.read())
-            self.file.seek(pos)
-            self._last_part_pos = self._buffer_file_size
+            self.file.seek(0)
+            self.file.truncate()
 
     def close(self):
         if self._is_dirty:
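
The boto3 backend drives the same cycle through the object-resource API; a sketch under the same placeholder assumptions (note that every part except the last must be at least 5 MB).

    import boto3

    s3 = boto3.resource('s3')
    obj = s3.Object('example-bucket', 'example-key')  # placeholder names
    mp = obj.initiate_multipart_upload()
    part = mp.Part(1)  # part numbers start at 1
    response = part.upload(Body=b'0' * 5 * 1024 * 1024)
    mp.complete(MultipartUpload={
        'Parts': [{'ETag': response['ETag'], 'PartNumber': 1}],
    })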
96 changes: 59 additions & 37 deletions tests/test_s3boto.py
@@ -4,7 +4,6 @@
 import mock
 
 import datetime
-import os
 
 from boto.exception import S3ResponseError
 from boto.s3.key import Key
@@ -171,6 +170,65 @@ def test_storage_open_write(self):
         )
         file._multipart.complete_upload.assert_called_once_with()
 
+    def test_storage_write_beyond_buffer_size(self):
+        """
+        Test writing content that exceeds the buffer size
+        """
+        name = 'test_open_for_writing_beyond_buffer_size.txt'
+
+        # Set the encryption flag used for multipart uploads
+        self.storage.encryption = True
+        # Set the ACL header used when creating/writing data.
+        self.storage.bucket.connection.provider.acl_header = 'x-amz-acl'
+        # Set the mocked key's bucket
+        self.storage.bucket.get_key.return_value.bucket = self.storage.bucket
+        # Set the name of the mock object
+        self.storage.bucket.get_key.return_value.name = name
+
+        file = self.storage.open(name, 'w')
+        self.storage.bucket.get_key.assert_called_with(name)
+
+        # Initiate the multipart upload
+        file.write('')
+        self.storage.bucket.initiate_multipart_upload.assert_called_with(
+            name,
+            headers={
+                'Content-Type': 'text/plain',
+                'x-amz-acl': 'public-read',
+            },
+            reduced_redundancy=self.storage.reduced_redundancy,
+            encrypt_key=True,
+        )
+
+        # Keep track of the content that would be uploaded to S3
+        def store_uploaded_part(fp, *args, **kwargs):
+            store_uploaded_part.content += fp.read().decode('utf-8')
+        store_uploaded_part.content = ''
+        file._multipart.upload_part_from_file.side_effect = store_uploaded_part
+
+        # Write content at least twice as long as the buffer size
+        written_content = ''
+        counter = 1
+        while len(written_content) < 2 * file.buffer_size:
+            content = 'hello, aws {counter}\n'.format(counter=counter)
+            # Write more than just a few bytes in each iteration to keep the
+            # test reasonably fast
+            content += '*' * int(file.buffer_size / 10)
+            file.write(content)
+            written_content += content
+            counter += 1
+
+        # Save the internal file before closing
+        _file = file.file
+        file.close()
+        file._multipart.upload_part_from_file.assert_has_calls([
+            mock.call(_file, 1, headers=self.storage.headers),
+            mock.call(_file, 2, headers=self.storage.headers),
+        ])
+        file._multipart.complete_upload.assert_called_once_with()
+
+        self.assertEqual(store_uploaded_part.content, written_content)
+
     def test_storage_exists_bucket(self):
         self.storage._connection.get_bucket.side_effect = S3ResponseError(404, 'No bucket')
         self.assertFalse(self.storage.exists(''))
@@ -309,42 +367,6 @@ def test_get_modified_time(self, getkey):
             tz.make_naive(tz.make_aware(
                 datetime.datetime.strptime(utcnow, ISO8601), tz.utc)))
 
-    def test_file_greater_than_5MB(self):
-        name = 'test_storage_save.txt'
-        content = ContentFile('0' * 10 * 1024 * 1024)
-
-        # Set the encryption flag used for multipart uploads
-        self.storage.encryption = True
-        # Set the ACL header used when creating/writing data.
-        self.storage.bucket.connection.provider.acl_header = 'x-amz-acl'
-        # Set the mocked key's bucket
-        self.storage.bucket.get_key.return_value.bucket = self.storage.bucket
-        # Set the name of the mock object
-        self.storage.bucket.get_key.return_value.name = name
-
-        def get_upload_file_size(fp):
-            pos = fp.tell()
-            fp.seek(0, os.SEEK_END)
-            length = fp.tell() - pos
-            fp.seek(pos)
-            return length
-
-        def upload_part_from_file(fp, part_num, *args, **kwargs):
-            if len(file_part_size) != part_num:
-                file_part_size.append(get_upload_file_size(fp))
-
-        file_part_size = []
-        f = self.storage.open(name, 'w')
-
-        # initiate the multipart upload
-        f.write('')
-        f._multipart.upload_part_from_file = upload_part_from_file
-        for chunk in content.chunks():
-            f.write(chunk)
-        f.close()
-
-        assert content.size == sum(file_part_size)
-
     def test_location_leading_slash(self):
         msg = (
             "S3BotoStorage.location cannot begin with a leading slash. "
165 changes: 62 additions & 103 deletions tests/test_s3boto3.py
@@ -241,6 +241,68 @@ def test_storage_open_write(self):
         multipart.complete.assert_called_once_with(
             MultipartUpload={'Parts': [{'ETag': '123', 'PartNumber': 1}]})
 
+    def test_storage_write_beyond_buffer_size(self):
+        """
+        Test writing content that exceeds the buffer size
+        """
+        name = 'test_open_for_writïng_beyond_buffer_size.txt'
+
+        # Set the encryption flag used for multipart uploads
+        self.storage.encryption = True
+        self.storage.reduced_redundancy = True
+        self.storage.default_acl = 'public-read'
+
+        file = self.storage.open(name, 'w')
+        self.storage.bucket.Object.assert_called_with(name)
+        obj = self.storage.bucket.Object.return_value
+        # Set the name of the mock object
+        obj.key = name
+
+        # Initiate the multipart upload
+        file.write('')
+        obj.initiate_multipart_upload.assert_called_with(
+            ACL='public-read',
+            ContentType='text/plain',
+            ServerSideEncryption='AES256',
+            StorageClass='REDUCED_REDUNDANCY'
+        )
+        multipart = obj.initiate_multipart_upload.return_value
+
+        # Write content at least twice as long as the buffer size
+        written_content = ''
+        counter = 1
+        while len(written_content) < 2 * file.buffer_size:
+            content = 'hello, aws {counter}\n'.format(counter=counter)
+            # Write more than just a few bytes in each iteration to keep the
+            # test reasonably fast
+            content += '*' * int(file.buffer_size / 10)
+            file.write(content)
+            written_content += content
+            counter += 1
+
+        # Set the parts the mocked multipart upload reports on completion
+        multipart.parts.all.return_value = [
+            mock.MagicMock(e_tag='123', part_number=1),
+            mock.MagicMock(e_tag='456', part_number=2)
+        ]
+        file.close()
+        self.assertListEqual(
+            multipart.Part.call_args_list,
+            [mock.call(1), mock.call(2)]
+        )
+        part = multipart.Part.return_value
+        uploaded_content = ''.join(
+            (args_list[1]['Body'].decode('utf-8')
+             for args_list in part.upload.call_args_list)
+        )
+        self.assertEqual(uploaded_content, written_content)
+        multipart.complete.assert_called_once_with(
+            MultipartUpload={'Parts': [
+                {'ETag': '123', 'PartNumber': 1},
+                {'ETag': '456', 'PartNumber': 2},
+            ]}
+        )
+
     def test_auto_creating_bucket(self):
         self.storage.auto_create_bucket = True
         Bucket = mock.MagicMock()
@@ -446,109 +508,6 @@ def thread_storage_connection():
         # Connection for each thread needs to be unique
         self.assertIsNot(connections[0], connections[1])
 
-    def test_file_greater_than_5mb(self):
-        """
-        test writing a large file in a single part so that the buffer is flushed
-        only on close
-        """
-        name = 'test_storage_save.txt'
-        content = '0' * 10 * 1024 * 1024
-
-        # set the encryption flag used for multipart uploads
-        self.storage.encryption = True
-        self.storage.reduced_redundancy = True
-        self.storage.default_acl = 'public-read'
-
-        f = self.storage.open(name, 'w')
-        self.storage.bucket.Object.assert_called_with(name)
-        obj = self.storage.bucket.Object.return_value
-        # set the name of the mock object
-        obj.key = name
-        multipart = obj.initiate_multipart_upload.return_value
-        part = multipart.Part.return_value
-        multipart.parts.all.return_value = [mock.MagicMock(e_tag='123', part_number=1)]
-
-        with mock.patch.object(f, '_flush_write_buffer') as method:
-            f.write(content)
-            self.assertFalse(method.called)  # buffer not flushed on write
-
-        assert f._file_part_size == len(content)
-        obj.initiate_multipart_upload.assert_called_with(
-            ACL='public-read',
-            ContentType='text/plain',
-            ServerSideEncryption='AES256',
-            StorageClass='REDUCED_REDUNDANCY'
-        )
-
-        with mock.patch.object(f, '_flush_write_buffer', wraps=f._flush_write_buffer) as method:
-            f.close()
-        method.assert_called_with()  # buffer flushed on close
-        multipart.Part.assert_called_with(1)
-        part.upload.assert_called_with(Body=content.encode('utf-8'))
-        multipart.complete.assert_called_once_with(
-            MultipartUpload={'Parts': [{'ETag': '123', 'PartNumber': 1}]})
-
-    def test_file_write_after_exceeding_5mb(self):
-        """
-        test writing a large file in two parts so that the buffer is flushed
-        on write and on close
-        """
-        name = 'test_storage_save.txt'
-        content1 = '0' * 5 * 1024 * 1024
-        content2 = '0'
-
-        # set the encryption flag used for multipart uploads
-        self.storage.encryption = True
-        self.storage.reduced_redundancy = True
-        self.storage.default_acl = 'public-read'
-
-        f = self.storage.open(name, 'w')
-        self.storage.bucket.Object.assert_called_with(name)
-        obj = self.storage.bucket.Object.return_value
-        # set the name of the mock object
-        obj.key = name
-        multipart = obj.initiate_multipart_upload.return_value
-        part = multipart.Part.return_value
-        multipart.parts.all.return_value = [
-            mock.MagicMock(e_tag='123', part_number=1),
-            mock.MagicMock(e_tag='456', part_number=2)
-        ]
-
-        with mock.patch.object(f, '_flush_write_buffer', wraps=f._flush_write_buffer) as method:
-            f.write(content1)
-            self.assertFalse(method.called)  # buffer doesn't get flushed on the first write
-            assert f._file_part_size == len(content1)  # file part size is the size of what's written
-            assert f._last_part_pos == 0  # no parts added, so last part stays at 0
-            f.write(content2)
-            method.assert_called_with()  # second write flushes buffer
-            multipart.Part.assert_called_with(1)  # first part created
-            part.upload.assert_called_with(Body=content1.encode('utf-8'))  # first part is uploaded
-            assert f._last_part_pos == len(content1)  # buffer spools to end of content1
-            assert f._buffer_file_size == len(content1) + len(content2)  # _buffer_file_size is total written
-            assert f._file_part_size == len(content2)  # new part is size of content2
-
-        obj.initiate_multipart_upload.assert_called_with(
-            ACL='public-read',
-            ContentType='text/plain',
-            ServerSideEncryption='AES256',
-            StorageClass='REDUCED_REDUNDANCY'
-        )
-        # save the internal file before closing
-        f.close()
-        multipart.Part.assert_called_with(2)
-        part.upload.assert_called_with(Body=content2.encode('utf-8'))
-        multipart.complete.assert_called_once_with(
-            MultipartUpload={'Parts': [
-                {
-                    'ETag': '123',
-                    'PartNumber': 1
-                },
-                {
-                    'ETag': '456',
-                    'PartNumber': 2
-                }
-            ]})
-
     def test_location_leading_slash(self):
         msg = (
             "S3Boto3Storage.location cannot begin with a leading slash. "
