diff --git a/CHANGES.md b/CHANGES.md index db93d5f2..b868e57d 100755 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,8 @@ # Changelog +## 1.3.0 +- `sqs_sensor` supports multiaccount integration. + ## 1.2.3 - Support Python 3 everywhere diff --git a/aws.yaml.example b/aws.yaml.example index 44d46431..406b45cf 100755 --- a/aws.yaml.example +++ b/aws.yaml.example @@ -11,8 +11,17 @@ service_notifications_sensor: path: /my-path sqs_sensor: + roles: + - arn:aws:iam::123456789098:role/rolename1 + - arn:aws:iam::901234567812:role/rolename2 + - arn:aws:iam::567890123489:role/rolename3 input_queues: - - queuename + - queue_name_1 + - https://sqs.us-east-1.amazonaws.com/567890123489/queue_name_2 + - https://sqs.us-west-2.amazonaws.com/123456789098/queue_name_3 + - https://sqs.us-west-2.amazonaws.com/123456789098/queue_name_4 + - https://sqs.eu-west-1.amazonaws.com/901234567812/queue_name_5 + - https://sqs.sa-east-1.amazonaws.com/567890123489/queue_name_6 sqs_other: max_number_of_messages: 1 diff --git a/config.schema.yaml b/config.schema.yaml index 835a9387..e6ed811e 100755 --- a/config.schema.yaml +++ b/config.schema.yaml @@ -41,10 +41,15 @@ type: object properties: input_queues: - description: "Names of queue to fetch messages from Amazon SQS" + description: "Names or URLs of queues to fetch messages from Amazon SQS" type: "array" items: type: "string" + roles: + type: "array" + description: "ARNs of the roles which allow queues to be fetched for messages" + items: + type: "string" sqs_other: type: object properties: diff --git a/pack.yaml b/pack.yaml index aa0b4da7..329e6a22 100755 --- a/pack.yaml +++ b/pack.yaml @@ -19,7 +19,7 @@ keywords: - SQS - lambda - kinesis -version : 1.2.3 +version : 1.3.0 author : StackStorm, Inc. email : info@stackstorm.com python_versions: diff --git a/sensors/sqs_sensor.py b/sensors/sqs_sensor.py index 77b5fcb0..3076626b 100755 --- a/sensors/sqs_sensor.py +++ b/sensors/sqs_sensor.py @@ -38,11 +38,14 @@ import six import json +import sys + from boto3.session import Session from botocore.exceptions import ClientError from botocore.exceptions import NoRegionError from botocore.exceptions import NoCredentialsError from botocore.exceptions import EndpointConnectionError +from collections import defaultdict from st2reactor.sensor.base import PollingSensor @@ -55,19 +58,36 @@ def __init__(self, sensor_service, config=None, poll_interval=5): def setup(self): self._logger = self._sensor_service.get_logger(name=self.__class__.__name__) - self.session = None - self.sqs_res = None + self.account_id = None + self.credentials = {} + self.sessions = {} + self.cross_roles = {} + self.sqs_res = defaultdict(dict) def poll(self): # setting SQS ServiceResource object from the parameter of datastore or configuration file self._may_setup_sqs() for queue in self.input_queues: - msgs = self._receive_messages(queue=self._get_queue_by_name(queue), - num_messages=self.max_number_of_messages) + account_id, region = self._get_info(queue) + + while True: + try: + msgs = self._receive_messages(queue=self._get_queue(queue, account_id, region), + num_messages=self.max_number_of_messages) + except ClientError as e: + if e.response['Error']['Code'] == 'ExpiredToken': + self._setup_multiaccount_session(account_id) + continue + raise + break + for msg in msgs: if msg: - payload = {"queue": queue, "body": json.loads(msg.body)} + payload = {"queue": queue, + "account_id": account_id, + "region": region, + "body": json.loads(msg.body)} self._sensor_service.dispatch(trigger="aws.sqs_new_message", payload=payload) msg.delete() @@ -89,7 +109,7 @@ def _get_config_entry(self, key, prefix=None): ''' Get configuration values either from Datastore or config file. ''' config = self.config if prefix: - config = self._config.get(prefix, {}) + config = self.config.get(prefix, {}) value = self._sensor_service.get_value('aws.%s' % (key), local=False) if not value: @@ -101,61 +121,156 @@ def _get_config_entry(self, key, prefix=None): return value def _may_setup_sqs(self): - queues = self._get_config_entry(key='input_queues', prefix='sqs_sensor') + self.access_key_id = self._get_config_entry('aws_access_key_id') + self.secret_access_key = self._get_config_entry('aws_secret_access_key') + self.aws_region = self._get_config_entry('region') + self.max_number_of_messages = self._get_config_entry('max_number_of_messages', + prefix='sqs_other') + + if not self.account_id: + self._setup_session() + queues = self._get_config_entry(key='input_queues', prefix='sqs_sensor') # XXX: This is a hack as from datastore we can only receive a string while # from config.yaml we can receive a list if isinstance(queues, six.string_types): - self.input_queues = [x.strip() for x in queues.split(',')] + self.input_queues = [six.moves.urllib.parse.urlparse(x.strip()) for x in + queues.split(',')] elif isinstance(queues, list): - self.input_queues = queues + self.input_queues = [six.moves.urllib.parse.urlparse(queue) for queue in queues] else: self.input_queues = [] - self.aws_access_key = self._get_config_entry('aws_access_key_id') - self.aws_secret_key = self._get_config_entry('aws_secret_access_key') - self.aws_region = self._get_config_entry('region') - - self.max_number_of_messages = self._get_config_entry('max_number_of_messages', - prefix='sqs_other') - # checker configuration is update, or not - def _is_same_credentials(): - c = self.session.get_credentials() + def _is_same_credentials(session, account_id): + if not session: + return False + + c = session.get_credentials() return c is not None and \ - c.access_key == self.aws_access_key and \ - c.secret_key == self.aws_secret_key and \ - self.session.region_name == self.aws_region + c.access_key == self.credentials[account_id][0] and \ + c.secret_key == self.credentials[account_id][1] and \ + (account_id == self.account_id or c.token == self.credentials[account_id][2]) + + # Build a map between 'account_id' and its 'role arn' by parsing the matching config entry + self.cross_roles = { + arn.split(':')[4]: arn + for arn in self._get_config_entry('roles', 'sqs_sensor') or [] + } + required_accounts = {self._get_info(queue)[0] for queue in self.input_queues} - if self.session is None or not _is_same_credentials(): - self._setup_sqs() + for account_id in required_accounts: + if account_id != self.account_id and account_id not in self.cross_roles: + continue - def _setup_sqs(self): - ''' Setup Boto3 structures ''' - self._logger.debug('Setting up SQS resources') - self.session = Session(aws_access_key_id=self.aws_access_key, - aws_secret_access_key=self.aws_secret_key, - region_name=self.aws_region) + session = self.sessions.get(account_id) + if not _is_same_credentials(session, account_id): + if account_id == self.account_id: + self._setup_session() + else: + self._setup_multiaccount_session(account_id) + + def _setup_session(self): + ''' Setup Boto3 session ''' + session = Session(aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key) + + if not self.account_id: + self.account_id = session.client('sts').get_caller_identity().get('Account') + self.credentials[self.account_id] = (self.access_key_id, self.secret_access_key, None) + + self.sessions[self.account_id] = session + self.sqs_res.pop(self.account_id, None) + + def _setup_multiaccount_session(self, account_id): + ''' Assume role and setup session for the cross-account capability''' + try: + assumed_role = self.sessions[self.account_id].client('sts').assume_role( + RoleArn=self.cross_roles[account_id], + RoleSessionName='StackStormEvents' + ) + except ClientError: + self._logger.error('Could not assume role on %s', account_id) + return + + self.credentials[account_id] = (assumed_role["Credentials"]["AccessKeyId"], + assumed_role["Credentials"]["SecretAccessKey"], + assumed_role["Credentials"]["SessionToken"]) + + session = Session( + aws_access_key_id=self.credentials[account_id][0], + aws_secret_access_key=self.credentials[account_id][1], + aws_session_token=self.credentials[account_id][2] + ) + self.sessions[account_id] = session + self.sqs_res.pop(account_id, None) + + def _setup_sqs(self, session, account_id, region): + ''' Setup SQS resources''' + if region in self.sqs_res[account_id]: + return self.sqs_res[account_id][region] try: - self.sqs_res = self.session.resource('sqs') + self.sqs_res[account_id][region] = session.resource('sqs', region_name=region) + return self.sqs_res[account_id][region] except NoRegionError: - self._logger.warning("The specified region '%s' is invalid", self.aws_region) + self._logger.error("The specified region '%s' for account %s is invalid.", + region, account_id) + + def _get_info(self, queue): + ''' Retrieve the account ID and region from the queue URL ''' + # Pull default values from previous configuration + account_id = self.account_id + aws_region = self.aws_region + + # Netloc will be empty if the queue is just a name + if queue.netloc: + try: + account_id = queue.path.split('/')[1] + except IndexError as e: + six.reraise(type(e), type(e)( + "Queue URL must contain the account ID as the first part of the path, " + "eg: https://sqs..amazonaws.com//"), + sys.exc_info()[2]) + else: + self._logger.debug("Using %s as account_id", account_id) + + try: + aws_region = queue.netloc.split('.')[1] + except IndexError as e: + six.reraise(type(e), type(e)( + "Queue URL must contain the AWS region, " + "eg: https://sqs..amazonaws.com/..."), + sys.exc_info()[2]) + else: + self._logger.debug("Using %s as the AWS region", aws_region) + + return account_id, aws_region + + def _get_queue(self, queue, account_id, region): + ''' Fetch QUEUE by its name or URL and create new one if queue doesn't exist ''' + if account_id not in self.sessions: + self._logger.error('Session for account id %s does not exist', account_id) + return None + + sqs_res = self._setup_sqs(self.sessions[account_id], account_id, region) + if sqs_res is None: + return None - def _get_queue_by_name(self, queueName): - ''' Fetch QUEUE by it's name create new one if queue doesn't exist ''' try: - return self.sqs_res.get_queue_by_name(QueueName=queueName) + if queue.netloc: + return sqs_res.Queue(six.moves.urllib.parse.urlunparse(queue)) + return sqs_res.get_queue_by_name(QueueName=queue.path.split('/')[-1]) except ClientError as e: if e.response['Error']['Code'] == 'AWS.SimpleQueueService.NonExistentQueue': - self._logger.warning("SQS Queue: %s doesn't exist, creating it.", queueName) - return self.sqs_res.create_queue(QueueName=queueName) + self._logger.warning("SQS Queue: %s doesn't exist, creating it.", queue) + return sqs_res.create_queue(QueueName=queue.path.split('/')[-1]) elif e.response['Error']['Code'] == 'InvalidClientTokenId': - self._logger.warning("Cloudn't operate sqs because of invalid credential config") + self._logger.warning("Couldn't operate sqs because of invalid credential config") else: raise except NoCredentialsError: - self._logger.warning("Cloudn't operate sqs because of invalid credential config") + self._logger.warning("Couldn't operate sqs because of invalid credential config") except EndpointConnectionError as e: self._logger.warning(e) diff --git a/sensors/sqs_sensor.yaml b/sensors/sqs_sensor.yaml index 2bb4ca34..243186d5 100755 --- a/sensors/sqs_sensor.yaml +++ b/sensors/sqs_sensor.yaml @@ -12,5 +12,9 @@ properties: queue: type: "string" + account_id: + type: "string" + region: + type: "string" body: type: "object" diff --git a/tests/fixtures/full.yaml b/tests/fixtures/full.yaml index 8b8547e2..ffbaefde 100755 --- a/tests/fixtures/full.yaml +++ b/tests/fixtures/full.yaml @@ -10,7 +10,8 @@ service_notifications_sensor: path: "/my-path" sqs_sensor: - input_queues: "input_queue" + input_queues: + - "input_queue" sqs_other: max_number_of_messages: 1 diff --git a/tests/fixtures/mixed.yaml b/tests/fixtures/mixed.yaml new file mode 100644 index 00000000..38910aa0 --- /dev/null +++ b/tests/fixtures/mixed.yaml @@ -0,0 +1,20 @@ +--- +aws_access_key_id: "access_key_id" +aws_secret_access_key: "secret_key" +region: "us-west-1" +st2_user_data: "" + +service_notifications_sensor: + host: "localhost" + port: 12345 + path: "/my-path" + +sqs_sensor: + roles: + - "arn:aws:iam::345678901223:role/rolename1" + input_queues: + - "input_queue" + - "https://sqs.us-east-1.amazonaws.com/345678901223/queue_name_2" + +sqs_other: + max_number_of_messages: 1 diff --git a/tests/fixtures/multiaccount.yaml b/tests/fixtures/multiaccount.yaml new file mode 100644 index 00000000..08b0de25 --- /dev/null +++ b/tests/fixtures/multiaccount.yaml @@ -0,0 +1,19 @@ +--- +aws_access_key_id: "access_key_id" +aws_secret_access_key: "secret_key" +region: "us-west-1" +st2_user_data: "" + +service_notifications_sensor: + host: "localhost" + port: 12345 + path: "/my-path" + +sqs_sensor: + roles: + - "arn:aws:iam::345678901223:role/rolename1" + input_queues: + - "https://sqs.us-east-1.amazonaws.com/345678901223/queue_name_2" + +sqs_other: + max_number_of_messages: 1 diff --git a/tests/test_sensor_sqs.py b/tests/test_sensor_sqs.py index 1e62f61c..d997a4f2 100644 --- a/tests/test_sensor_sqs.py +++ b/tests/test_sensor_sqs.py @@ -1,9 +1,11 @@ import mock +import six import yaml from boto3.session import Session from botocore.exceptions import ClientError from botocore.exceptions import NoCredentialsError +from botocore.exceptions import NoRegionError from botocore.exceptions import EndpointConnectionError from st2tests.base import BaseSensorTestCase @@ -20,6 +22,24 @@ def __init__(self, msgs=[]): def get_queue_by_name(self, **kwargs): return SQSSensorTestCase.MockQueue(self.msgs) + def Queue(self, queue): + return SQSSensorTestCase.MockQueue(self.msgs) + + class MockResourceNonExistentQueue(object): + def __init__(self, msgs=[]): + self.msgs = msgs + + def get_queue_by_name(self, **kwargs): + raise ClientError({'Error': {'Code': 'AWS.SimpleQueueService.NonExistentQueue'}}, + 'sqs_test') + + def Queue(self, queue): + raise ClientError({'Error': {'Code': 'AWS.SimpleQueueService.NonExistentQueue'}}, + 'sqs_test') + + def create_queue(self, **kwargs): + return SQSSensorTestCase.MockQueue(self.msgs) + class MockResourceRaiseClientError(object): def __init__(self, error_code=''): self.error_code = error_code @@ -27,14 +47,45 @@ def __init__(self, error_code=''): def get_queue_by_name(self, **kwargs): raise ClientError({'Error': {'Code': self.error_code}}, 'sqs_test') + def Queue(self, queue): + raise ClientError({'Error': {'Code': self.error_code}}, 'sqs_test') + class MockResourceRaiseNoCredentialsError(object): def get_queue_by_name(self, **kwargs): raise NoCredentialsError() + def Queue(self, queue): + raise NoCredentialsError() + class MockResourceRaiseEndpointConnectionError(object): def get_queue_by_name(self, **kwargs): raise EndpointConnectionError(endpoint_url='') + def Queue(self, queue): + raise EndpointConnectionError(endpoint_url='') + + class MockStsClient(object): + def __init__(self): + self.meta = mock.Mock(service_model={}) + + def get_caller_identity(self): + ci = mock.Mock() + ci.get = lambda attribute: '111222333444' if attribute == 'Account' else None + return ci + + def assume_role(self, RoleArn, RoleSessionName): + return { + 'Credentials': { + 'AccessKeyId': 'access_key_id_example', + 'SecretAccessKey': 'secret_access_key_example', + 'SessionToken': 'session_token_example' + } + } + + class MockStsClientRaiseClientError(MockStsClient): + def assume_role(self, RoleArn, RoleSessionName): + raise ClientError({'Error': {'Code': 'AccessDenied'}}, 'sqs_test') + class MockQueue(object): def __init__(self, msgs=[]): self.dummy_messages = [SQSSensorTestCase.MockMessage(x) for x in msgs] @@ -54,10 +105,13 @@ def setUp(self): self.full_config = self.load_yaml('full.yaml') self.blank_config = self.load_yaml('blank.yaml') + self.multiaccount_config = self.load_yaml('multiaccount.yaml') + self.mixed_config = self.load_yaml('mixed.yaml') def load_yaml(self, filename): return yaml.safe_load(self.get_fixture_content(filename)) + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) def test_poll_with_blank_config(self): sensor = self.get_sensor_instance(config=self.blank_config) @@ -66,18 +120,29 @@ def test_poll_with_blank_config(self): self.assertEqual(self.get_dispatched_triggers(), []) + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) @mock.patch.object(Session, 'resource', mock.Mock(return_value=MockResource())) - def test_poll_without_message(self): - sensor = self.get_sensor_instance(config=self.full_config) + def _poll_without_message(self, config): + sensor = self.get_sensor_instance(config=config) sensor.setup() sensor.poll() self.assertEqual(self.get_dispatched_triggers(), []) + def test_poll_without_message_full_config(self): + self._poll_without_message(self.full_config) + + def test_poll_without_message_multiaccount_config(self): + self._poll_without_message(self.multiaccount_config) + + def test_poll_without_message_mixed_config(self): + self._poll_without_message(self.mixed_config) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) @mock.patch.object(Session, 'resource', mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) - def test_poll_with_message(self): - sensor = self.get_sensor_instance(config=self.full_config) + def _poll_with_message(self, config): + sensor = self.get_sensor_instance(config=config) sensor.setup() sensor.poll() @@ -85,9 +150,39 @@ def test_poll_with_message(self): self.assertTriggerDispatched(trigger='aws.sqs_new_message') self.assertNotEqual(self.get_dispatched_triggers(), []) - @mock.patch.object(Session, 'resource', mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) + def test_poll_with_message_full_config(self): + self._poll_with_message(self.full_config) + + def test_poll_with_message_multiaccount_config(self): + self._poll_with_message(self.multiaccount_config) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) + @mock.patch.object(Session, 'resource', + mock.Mock(return_value=MockResourceNonExistentQueue(['{"foo":"bar"}']))) + def _poll_with_non_existent_queue(self, config): + sensor = self.get_sensor_instance(config=config) + + sensor.setup() + sensor.poll() + + contexts = self.get_dispatched_triggers() + self.assertNotEqual(contexts, []) + self.assertTriggerDispatched(trigger='aws.sqs_new_message') + + def test_poll_with_non_existent_queue_full_config(self): + self._poll_with_non_existent_queue(self.full_config) + + def test_poll_with_non_existent_queue_multiaccount_config(self): + self._poll_with_non_existent_queue(self.multiaccount_config) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) + @mock.patch.object(Session, 'resource', + mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) def test_set_input_queues_config_dynamically(self): sensor = self.get_sensor_instance(config=self.blank_config) + sensor._sensor_service.set_value('aws.roles', + ['arn:aws:iam::123456789098:role/rolename1'], + local=False) sensor.setup() # set credential mock to prevent sending request to AWS @@ -104,18 +199,36 @@ def test_set_input_queues_config_dynamically(self): sensor._sensor_service.set_value('aws.input_queues', 'fuga,puyo', local=False) sensor.poll() + # update input_queues to check this is reflected + sensor._sensor_service.set_value( + 'aws.input_queues', + 'https://sqs.us-west-2.amazonaws.com/123456789098/queue_name_3', + local=False + ) + sensor.poll() + contexts = self.get_dispatched_triggers() self.assertNotEqual(contexts, []) self.assertTriggerDispatched(trigger='aws.sqs_new_message') # get message from queue 'hoge', 'fuga' then 'puyo' - self.assertEqual([x['payload']['queue'] for x in contexts], ['hoge', 'fuga', 'puyo']) + self.assertEqual([x['payload']['queue'] for x in contexts], + [six.moves.urllib.parse.urlparse(queue) for queue in + ['hoge', 'fuga', 'puyo', + 'https://sqs.us-west-2.amazonaws.com/123456789098/queue_name_3']]) - @mock.patch.object(Session, 'resource', mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) + @mock.patch.object(Session, 'resource', + mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) def test_set_input_queues_config_with_list(self): # set 'input_queues' config with list type config = self.full_config - config['sqs_sensor']['input_queues'] = ['foo', 'bar'] + config['sqs_sensor']['input_queues'] = [ + 'foo', + 'bar', + 'https://sqs.us-west-2.amazonaws.com/123456789098/queue_name_3' + ] + config['sqs_sensor']['roles'] = ['arn:aws:iam::123456789098:role/rolename1'] sensor = self.get_sensor_instance(config=config) sensor.setup() @@ -124,34 +237,119 @@ def test_set_input_queues_config_with_list(self): contexts = self.get_dispatched_triggers() self.assertNotEqual(contexts, []) self.assertTriggerDispatched(trigger='aws.sqs_new_message') - self.assertEqual([x['payload']['queue'] for x in contexts], ['foo', 'bar']) + self.assertEqual([x['payload']['queue'] for x in contexts], + [six.moves.urllib.parse.urlparse(queue) for queue in + ['foo', 'bar', + 'https://sqs.us-west-2.amazonaws.com/123456789098/queue_name_3']]) + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) @mock.patch.object(Session, 'resource', - mock.Mock(return_value=MockResourceRaiseClientError('InvalidClientTokenId'))) - def test_fails_with_invalid_token(self): - sensor = self.get_sensor_instance(config=self.full_config) + mock.Mock( + return_value=MockResourceRaiseClientError('InvalidClientTokenId')) + ) + def _fails_with_invalid_token(self, config): + sensor = self.get_sensor_instance(config=config) sensor.setup() sensor.poll() self.assertEqual(self.get_dispatched_triggers(), []) + def test_fails_with_invalid_token_full_config(self): + self._fails_with_invalid_token(self.full_config) + + def test_fails_with_invalid_token_multiaccount_config(self): + self._fails_with_invalid_token(self.multiaccount_config) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) @mock.patch.object(Session, 'resource', mock.Mock(return_value=MockResourceRaiseNoCredentialsError())) - def test_fails_without_credentials(self): - sensor = self.get_sensor_instance(config=self.full_config) + def _fails_without_credentials(self, config): + sensor = self.get_sensor_instance(config=config) sensor.setup() sensor.poll() self.assertEqual(self.get_dispatched_triggers(), []) + def test_fails_without_credentials_full_config(self): + self._fails_without_credentials(self.full_config) + + def test_fails_without_credentials_multiaccount_config(self): + self._fails_without_credentials(self.multiaccount_config) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) @mock.patch.object(Session, 'resource', mock.Mock(return_value=MockResourceRaiseEndpointConnectionError())) - def test_fails_with_invalid_region(self): - sensor = self.get_sensor_instance(config=self.full_config) + def _fails_with_invalid_region(self, config): + sensor = self.get_sensor_instance(config=config) sensor.setup() sensor.poll() self.assertEqual(self.get_dispatched_triggers(), []) + + def test_fails_with_invalid_region_full_config(self): + self._fails_with_invalid_region(self.full_config) + + def test_fails_with_invalid_region_multiaccount_config(self): + self._fails_with_invalid_region(self.multiaccount_config) + + @mock.patch.object(Session, 'client', + mock.Mock(return_value=MockStsClientRaiseClientError())) + @mock.patch.object(Session, 'resource', + mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) + def _fails_assuming_role(self, config): + sensor = self.get_sensor_instance(config=config) + + sensor.setup() + sensor.poll() + + def test_fails_assuming_role_full_config(self): + self._fails_assuming_role(self.full_config) + + self.assertTriggerDispatched(trigger='aws.sqs_new_message') + self.assertNotEqual(self.get_dispatched_triggers(), []) + + def test_fails_assuming_role_multiaccount_config(self): + self._fails_assuming_role(self.multiaccount_config) + self.assertEqual(self.get_dispatched_triggers(), []) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) + @mock.patch.object(Session, 'resource', + mock.Mock(side_effect=NoRegionError( + service_name='sqs', region_name='us-east-1'))) + def test_fails_creating_sqs_resource(self): + sensor = self.get_sensor_instance(config=self.mixed_config) + + sensor.setup() + sensor.poll() + + self.assertEqual(self.get_dispatched_triggers(), []) + + @mock.patch.object(Session, 'client', mock.Mock(return_value=MockStsClient())) + @mock.patch.object(Session, 'resource', + mock.Mock(return_value=MockResource(['{"foo":"bar"}']))) + def _poll_with_missing_arn(self, config): + config['sqs_sensor']['roles'] = [] + + sensor = self.get_sensor_instance(config=config) + sensor.setup() + sensor.poll() + + def test_poll_with_missing_arn_full_config(self): + self._poll_with_missing_arn(self.full_config) + + self.assertNotEqual(self.get_dispatched_triggers(), []) + self.assertTriggerDispatched(trigger='aws.sqs_new_message') + + def test_poll_with_missing_arn_multiaccount_config(self): + self._poll_with_missing_arn(self.multiaccount_config) + + self.assertEqual(self.get_dispatched_triggers(), []) + + def test_poll_with_missing_arn_mixed_config(self): + self._poll_with_missing_arn(self.mixed_config) + + self.assertNotEqual(self.get_dispatched_triggers(), []) + self.assertTriggerDispatched(trigger='aws.sqs_new_message')