diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index da79a4f233b3e..64b852cbd7676 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1248,6 +1248,12 @@ type: string example: ~ default: "airflow.utils.email.send_email_smtp" + - name: email_conn_id + description: Email connection to use + version_added: ~ + type: string + example: ~ + default: "smtp_default" - name: default_email_on_retry description: | Whether email alerts should be sent when a task is retried diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 3c9adeba1c557..48d4111020ed0 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -620,6 +620,9 @@ session_lifetime_minutes = 43200 # Email backend to use email_backend = airflow.utils.email.send_email_smtp +# Email connection to use +email_conn_id = smtp_default + # Whether email alerts should be sent when a task is retried default_email_on_retry = True diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index 767176d7fdade..8cc9305101245 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -80,6 +80,7 @@ page_size = 100 [email] email_backend = airflow.utils.email.send_email_smtp +email_conn_id = smtp_default [smtp] smtp_host = localhost diff --git a/airflow/operators/email.py b/airflow/operators/email.py index 4bccbc34c0d64..5ae5f8022504b 100644 --- a/airflow/operators/email.py +++ b/airflow/operators/email.py @@ -63,6 +63,7 @@ def __init__( # pylint: disable=invalid-name bcc: Optional[Union[List[str], str]] = None, mime_subtype: str = 'mixed', mime_charset: str = 'utf-8', + conn_id: Optional[str] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -74,6 +75,7 @@ def __init__( # pylint: disable=invalid-name self.bcc = bcc self.mime_subtype = mime_subtype self.mime_charset = mime_charset + self.conn_id = conn_id def execute(self, context): send_email( @@ -85,4 +87,5 @@ def execute(self, context): bcc=self.bcc, mime_subtype=self.mime_subtype, mime_charset=self.mime_charset, + conn_id=self.conn_id, ) diff --git a/airflow/providers/sendgrid/utils/emailer.py b/airflow/providers/sendgrid/utils/emailer.py index f95fd3c25aede..df832a4a2843e 100644 --- a/airflow/providers/sendgrid/utils/emailer.py +++ b/airflow/providers/sendgrid/utils/emailer.py @@ -21,6 +21,7 @@ import logging import mimetypes import os +import warnings from typing import Dict, Iterable, Optional, Union import sendgrid @@ -36,6 +37,8 @@ SandBoxMode, ) +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook from airflow.utils.email import get_email_address_list log = logging.getLogger(__name__) @@ -43,7 +46,7 @@ AddressesType = Union[str, Iterable[str]] -def send_email( +def send_email( # pylint: disable=too-many-locals to: AddressesType, subject: str, html_content: str, @@ -51,6 +54,7 @@ def send_email( cc: Optional[AddressesType] = None, bcc: Optional[AddressesType] = None, sandbox_mode: bool = False, + conn_id: str = "sendgrid_default", **kwargs, ) -> None: """ @@ -115,11 +119,25 @@ def send_email( ) mail.add_attachment(attachment) - _post_sendgrid_mail(mail.get()) - - -def _post_sendgrid_mail(mail_data: Dict) -> None: - sendgrid_client = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY')) + _post_sendgrid_mail(mail.get(), conn_id) + + +def _post_sendgrid_mail(mail_data: Dict, conn_id: str = "sendgrid_default") -> None: + api_key = None + try: + conn = BaseHook.get_connection(conn_id) + api_key = conn.password + except AirflowException: + pass + if api_key is None: + warnings.warn( + "Fetching Sendgrid credentials from environment variables will be deprecated in a future " + "release. Please set credentials using a connection instead.", + PendingDeprecationWarning, + stacklevel=2, + ) + api_key = os.environ.get('SENDGRID_API_KEY') + sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key) response = sendgrid_client.client.mail.send.post(request_body=mail_data) # 2xx status code. if 200 <= response.status_code < 300: diff --git a/airflow/utils/email.py b/airflow/utils/email.py index 8e4359bd0ebf9..7d17027be4307 100644 --- a/airflow/utils/email.py +++ b/airflow/utils/email.py @@ -20,6 +20,7 @@ import logging import os import smtplib +import warnings from email.mime.application import MIMEApplication from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText @@ -27,7 +28,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException +from airflow.exceptions import AirflowConfigException, AirflowException log = logging.getLogger(__name__) @@ -36,16 +37,18 @@ def send_email( to: Union[List[str], Iterable[str]], subject: str, html_content: str, - files=None, - dryrun=False, - cc=None, - bcc=None, - mime_subtype='mixed', - mime_charset='utf-8', + files: Optional[List[str]] = None, + dryrun: bool = False, + cc: Optional[Union[str, Iterable[str]]] = None, + bcc: Optional[Union[str, Iterable[str]]] = None, + mime_subtype: str = 'mixed', + mime_charset: str = 'utf-8', + conn_id: Optional[str] = None, **kwargs, ): """Send email using backend specified in EMAIL_BACKEND.""" backend = conf.getimport('email', 'EMAIL_BACKEND') + backend_conn_id = conn_id or conf.get("email", "EMAIL_CONN_ID") to_list = get_email_address_list(to) to_comma_separated = ", ".join(to_list) @@ -59,6 +62,7 @@ def send_email( bcc=bcc, mime_subtype=mime_subtype, mime_charset=mime_charset, + conn_id=backend_conn_id, **kwargs, ) @@ -73,6 +77,7 @@ def send_email_smtp( bcc: Optional[Union[str, Iterable[str]]] = None, mime_subtype: str = 'mixed', mime_charset: str = 'utf-8', + conn_id: str = "smtp_default", **kwargs, ): """ @@ -94,7 +99,7 @@ def send_email_smtp( mime_charset=mime_charset, ) - send_mime_email(e_from=smtp_mail_from, e_to=recipients, mime_msg=msg, dryrun=dryrun) + send_mime_email(e_from=smtp_mail_from, e_to=recipients, mime_msg=msg, conn_id=conn_id, dryrun=dryrun) def build_mime_message( @@ -162,7 +167,9 @@ def build_mime_message( return msg, recipients -def send_mime_email(e_from: str, e_to: List[str], mime_msg: MIMEMultipart, dryrun: bool = False) -> None: +def send_mime_email( + e_from: str, e_to: List[str], mime_msg: MIMEMultipart, conn_id: str = "smtp_default", dryrun: bool = False +) -> None: """Send MIME email.""" smtp_host = conf.get('smtp', 'SMTP_HOST') smtp_port = conf.getint('smtp', 'SMTP_PORT') @@ -173,11 +180,28 @@ def send_mime_email(e_from: str, e_to: List[str], mime_msg: MIMEMultipart, dryru smtp_user = None smtp_password = None - try: - smtp_user = conf.get('smtp', 'SMTP_USER') - smtp_password = conf.get('smtp', 'SMTP_PASSWORD') - except AirflowConfigException: - log.debug("No user/password found for SMTP, so logging in with no authentication.") + smtp_user, smtp_password = None, None + if conn_id is not None: + try: + from airflow.hooks.base import BaseHook + + conn = BaseHook.get_connection(conn_id) + smtp_user = conn.login + smtp_password = conn.password + except AirflowException: + pass + if smtp_user is None or smtp_password is None: + warnings.warn( + "Fetching SMTP credentials from configuration variables will be deprecated in a future " + "release. Please set credentials using a connection instead.", + PendingDeprecationWarning, + stacklevel=2, + ) + try: + smtp_user = conf.get('smtp', 'SMTP_USER') + smtp_password = conf.get('smtp', 'SMTP_PASSWORD') + except AirflowConfigException: + log.debug("No user/password found for SMTP, so logging in with no authentication.") if not dryrun: for attempt in range(1, smtp_retry_limit + 1): diff --git a/tests/providers/sendgrid/utils/test_emailer.py b/tests/providers/sendgrid/utils/test_emailer.py index bb1a5f2cf5f20..cb6232c9bf4a0 100644 --- a/tests/providers/sendgrid/utils/test_emailer.py +++ b/tests/providers/sendgrid/utils/test_emailer.py @@ -95,7 +95,7 @@ def test_send_email_sendgrid_correct_email(self, mock_post): bcc=self.bcc, files=[f.name], ) - mock_post.assert_called_once_with(expected_mail_data) + mock_post.assert_called_once_with(expected_mail_data, "sendgrid_default") # Test the right email is constructed. @mock.patch.dict('os.environ', SENDGRID_MAIL_FROM='foo@bar.com', SENDGRID_MAIL_SENDER='Foo') @@ -110,7 +110,7 @@ def test_send_email_sendgrid_correct_email_extras(self, mock_post): personalization_custom_args=self.personalization_custom_args, categories=self.categories, ) - mock_post.assert_called_once_with(self.expected_mail_data_extras) + mock_post.assert_called_once_with(self.expected_mail_data_extras, "sendgrid_default") @mock.patch.dict('os.environ', clear=True) @mock.patch('airflow.providers.sendgrid.utils.emailer._post_sendgrid_mail') @@ -124,4 +124,4 @@ def test_send_email_sendgrid_sender(self, mock_post): from_email='foo@foo.bar', from_name='Foo Bar', ) - mock_post.assert_called_once_with(self.expected_mail_data_sender) + mock_post.assert_called_once_with(self.expected_mail_data_sender, "sendgrid_default") diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index a34dc7d845766..b680fdc46fdc5 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -76,6 +76,7 @@ def test_get_email_address_invalid_type_in_iterable(self): def setUp(self): conf.remove_option('email', 'EMAIL_BACKEND') + conf.remove_option('email', 'EMAIL_CONN_ID') @mock.patch('airflow.utils.email.send_email') def test_default_backend(self, mock_send_email): @@ -97,6 +98,7 @@ def test_custom_backend(self, mock_send_email): bcc=None, mime_charset='utf-8', mime_subtype='mixed', + conn_id='smtp_default', ) assert not mock_send_email.called @@ -192,6 +194,20 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl): mock_smtp.return_value.sendmail.assert_called_once_with('from', 'to', msg.as_string()) assert mock_smtp.return_value.quit.called + @mock.patch('smtplib.SMTP') + @mock.patch('airflow.hooks.base.BaseHook') + def test_send_mime_conn_id(self, mock_hook, mock_smtp): + msg = MIMEMultipart() + mock_conn = mock.Mock() + mock_conn.login = 'user' + mock_conn.password = 'password' + mock_hook.get_connection.return_value = mock_conn + utils.email.send_mime_email('from', 'to', msg, dryrun=False, conn_id='smtp_default') + mock_hook.get_connection.assert_called_with('smtp_default') + mock_smtp.return_value.login.assert_called_once_with('user', 'password') + mock_smtp.return_value.sendmail.assert_called_once_with('from', 'to', msg.as_string()) + assert mock_smtp.return_value.quit.called + @mock.patch('smtplib.SMTP_SSL') @mock.patch('smtplib.SMTP') def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):