Tdl 25859/handle s3 files race condition #67

Open
wants to merge 15 commits into base: master
2 changes: 0 additions & 2 deletions .circleci/config.yml
@@ -30,8 +30,6 @@ jobs:
       - run:
           name: 'Integration Tests'
           command: |
-            aws configure set aws_access_key_id "$AWS_ACCESS_KEY_ID"
-            aws configure set aws_secret_access_key "$AWS_SECRET_ACCESS_KEY"
             aws s3 cp s3://com-stitchdata-dev-deployment-assets/environments/tap-tester/tap_tester_sandbox dev_env.sh
             source dev_env.sh
             source /usr/local/share/virtualenvs/tap-tester/bin/activate
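
(Presumably the explicit aws configure set calls are removable because the AWS CLI reads AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY straight from the environment, so the aws s3 cp step still authenticates without them.)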
10 changes: 7 additions & 3 deletions tap_s3_csv/__init__.py
@@ -1,8 +1,10 @@
+from datetime import datetime
 import json
 import sys
 import singer
 
 from singer import metadata
+from singer import utils as singer_utils
 from tap_s3_csv.discover import discover_streams
 from tap_s3_csv import s3
 from tap_s3_csv.sync import sync_stream
@@ -27,7 +29,7 @@ def stream_is_selected(mdata):
     return mdata.get((), {}).get('selected', False)
 
 
-def do_sync(config, catalog, state):
+def do_sync(config, catalog, state, sync_start_time):
     LOGGER.info('Starting sync.')
 
     for stream in catalog['streams']:
@@ -43,7 +45,7 @@ def do_sync(config, catalog, state):
         singer.write_schema(stream_name, stream['schema'], key_properties)
 
         LOGGER.info("%s: Starting sync", stream_name)
-        counter_value = sync_stream(config, state, table_spec, stream)
+        counter_value = sync_stream(config, state, table_spec, stream, sync_start_time)
         LOGGER.info("%s: Completed sync (%s rows)", stream_name, counter_value)
 
     LOGGER.info('Done syncing.')
@@ -73,6 +75,8 @@ def main():
     config = args.config
 
     config['tables'] = validate_table_config(config)
+    now = datetime.now()
+    sync_start_time = singer_utils.strptime_with_tz(now.strftime("%Y-%m-%dT%H:%M:%SZ"))
 
     try:
         for page in s3.list_files_in_bucket(config):
@@ -84,7 +88,7 @@ def main():
     if args.discover:
         do_discover(args.config)
     elif args.properties:
-        do_sync(config, args.properties, args.state)
+        do_sync(config, args.properties, args.state, sync_start_time)
 
 
 if __name__ == '__main__':
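
A note on the sync_start_time setup above: round-tripping datetime.now() through strftime and singer_utils.strptime_with_tz truncates the timestamp to whole seconds and attaches a UTC offset. That matters because the last_modified values boto3 returns for S3 objects are timezone-aware, and Python refuses to compare naive and aware datetimes. A standalone sketch of the behavior, outside the tap:

from datetime import datetime
from singer import utils as singer_utils

# strftime drops sub-second precision; the trailing "Z" in the format string
# makes strptime_with_tz return a timezone-aware (UTC) datetime.
now = datetime.now()
sync_start_time = singer_utils.strptime_with_tz(now.strftime("%Y-%m-%dT%H:%M:%SZ"))
assert sync_start_time.tzinfo is not None

One caveat: datetime.now() is local wall-clock time, so stamping it with "Z" is only exact when the host runs in UTC.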
8 changes: 5 additions & 3 deletions tap_s3_csv/sync.py
@@ -20,7 +20,7 @@
 
 LOGGER = singer.get_logger()
 
-def sync_stream(config, state, table_spec, stream):
+def sync_stream(config, state, table_spec, stream, sync_start_time):
     table_name = table_spec['table_name']
     modified_since = singer_utils.strptime_with_tz(singer.get_bookmark(state, table_name, 'modified_since') or
                                                    config['start_date'])
@@ -40,8 +40,10 @@ def sync_stream(config, state, table_spec, stream):
     for s3_file in sorted(s3_files, key=lambda item: item['last_modified']):
         records_streamed += sync_table_file(
             config, s3_file['key'], table_spec, stream)
-
-        state = singer.write_bookmark(state, table_name, 'modified_since', s3_file['last_modified'].isoformat())
+        if s3_file['last_modified'] < sync_start_time:
+            state = singer.write_bookmark(state, table_name, 'modified_since', s3_file['last_modified'].isoformat())
+        else:
+            state = singer.write_bookmark(state, table_name, 'modified_since', sync_start_time.isoformat())
Comment on lines +43 to +46
Contributor:

Please add unit test for this change.
         singer.write_state(state)
 
     if s3.skipped_files_count:
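
The sync.py change above is the heart of the race-condition fix. An S3 listing only reliably covers objects last modified before the sync began, so if a file's last_modified is later than sync_start_time, other files uploaded mid-sync may exist that this run never saw; advancing the bookmark past sync_start_time would skip them on the next run. Capping the bookmark means the next run re-lists from the sync start instead. A minimal sketch of the capping rule in isolation (the function and variable names here are illustrative, not part of the tap):

from datetime import datetime, timezone

def next_bookmark(last_modified, sync_start_time):
    # Equivalent to the if/else in the diff: never advance the bookmark past
    # the moment this sync began, so files uploaded mid-sync are picked up on
    # the next run instead of being skipped.
    return min(last_modified, sync_start_time)

sync_start = datetime(2024, 8, 14, 12, 0, tzinfo=timezone.utc)
processed_file = datetime(2024, 8, 13, 12, 0, tzinfo=timezone.utc)  # modified before the sync began
mid_sync_file = datetime(2024, 8, 15, 12, 0, tzinfo=timezone.utc)   # modified while the sync ran

assert next_bookmark(processed_file, sync_start) == processed_file
assert next_bookmark(mid_sync_file, sync_start) == sync_start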
68 changes: 68 additions & 0 deletions tests/unittests/test_sync_stream.py
@@ -0,0 +1,68 @@
+import unittest
+from unittest.mock import patch, MagicMock
+from datetime import datetime
+from tap_s3_csv import sync_stream
+
+class TestSyncStream(unittest.TestCase):
+
+    @patch('tap_s3_csv.s3.get_input_files_for_table')
+    @patch('tap_s3_csv.sync.sync_table_file')
+    @patch('tap_s3_csv.singer.get_bookmark')
+    @patch('tap_s3_csv.singer.write_bookmark')
+    @patch('tap_s3_csv.singer.write_state')
+    @patch('tap_s3_csv.LOGGER')
+    def test_sync_stream_with_files_older_than_sync_start_time(self, mock_logger, mock_write_state, mock_write_bookmark, mock_get_bookmark, mock_sync_table_file, mock_get_input_files_for_table):
+        """
+        Tests the sync_stream function when the last_modified date of files is earlier than sync_start_time.
+        In this case, the bookmark is updated to the last_modified date of the file.
+        """
+        mock_get_bookmark.return_value = '2024-01-01T00:00:00Z'
+        mock_get_input_files_for_table.return_value = [
+            {'key': 'file1.csv', 'last_modified': datetime(2024, 8, 13, 12, 0, 0)}
+        ]
+        mock_sync_table_file.return_value = 1
+        mock_write_bookmark.return_value = '2024-08-13T12:00:00Z'
+
+        config = {'start_date': '2024-01-01T00:00:00Z'}
+        state = {}
+        table_spec = {'table_name': 'test_table'}
+        stream = None
+        sync_start_time = datetime(2024, 8, 14, 12, 0, 0)
+
+        records_streamed = sync_stream(config, state, table_spec, stream, sync_start_time)
+
+        self.assertEqual(records_streamed, 1)
+        # Verify that the bookmark was updated to the last_modified date of the file
+        mock_write_bookmark.assert_called_with(state, 'test_table', 'modified_since', '2024-08-13T12:00:00')
+        mock_write_state.assert_called_once()
+
+    @patch('tap_s3_csv.s3.get_input_files_for_table')
+    @patch('tap_s3_csv.sync.sync_table_file')
+    @patch('tap_s3_csv.singer.get_bookmark')
+    @patch('tap_s3_csv.singer.write_bookmark')
+    @patch('tap_s3_csv.singer.write_state')
+    @patch('tap_s3_csv.LOGGER')
+    def test_sync_stream_with_files_newer_than_sync_start_time(self, mock_logger, mock_write_state, mock_write_bookmark, mock_get_bookmark, mock_sync_table_file, mock_get_input_files_for_table):
+        """
+        Tests the sync_stream function when the last_modified date of files is later than sync_start_time.
+        In this case, the bookmark is updated to sync_start_time.
+        """
+        mock_get_bookmark.return_value = '2024-01-01T00:00:00Z'
+        mock_get_input_files_for_table.return_value = [
+            {'key': 'file1.csv', 'last_modified': datetime(2024, 8, 15, 12, 0, 0)}
+        ]
+        mock_sync_table_file.return_value = 1
+        mock_write_bookmark.return_value = '2024-08-15T12:00:00Z'
+
+        config = {'start_date': '2024-01-01T00:00:00Z'}
+        state = {}
+        table_spec = {'table_name': 'test_table'}
+        stream = None
+        sync_start_time = datetime(2024, 8, 14, 12, 0, 0)
+
+        records_streamed = sync_stream(config, state, table_spec, stream, sync_start_time)
+
+        self.assertEqual(records_streamed, 1)
+        # Verify that the bookmark was updated to the sync_start_time
+        mock_write_bookmark.assert_called_with(state, 'test_table', 'modified_since', sync_start_time.isoformat())
+        mock_write_state.assert_called_once()
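
The two tests above cover both branches of the new bookmark logic: a file older than sync_start_time (the bookmark follows the file) and a file newer than it (the bookmark is capped). To run just this module locally, something like the following should work from the repo root, assuming pytest is available in the dev virtualenv:

python -m pytest tests/unittests/test_sync_stream.py -q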