diff --git a/airflow/contrib/example_dags/example_gcp_compute.py b/airflow/contrib/example_dags/example_gcp_compute.py new file mode 100644 index 0000000000000..e4abe2e152a00 --- /dev/null +++ b/airflow/contrib/example_dags/example_gcp_compute.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG that starts, stops and sets the machine type of a Google Compute +Engine instance. + +This DAG relies on the following Airflow variables +https://airflow.apache.org/concepts.html#variables +* PROJECT_ID - Google Cloud Platform project where the Compute Engine instance exists. +* LOCATION - Google Cloud Platform zone where the instance exists. +* INSTANCE - Name of the Compute Engine instance. +* SHORT_MACHINE_TYPE_NAME - Machine type resource name to set, e.g. 'n1-standard-1'. 
+ See https://cloud.google.com/compute/docs/machine-types +""" + +import datetime + +import airflow +from airflow import models +from airflow.contrib.operators.gcp_compute_operator import GceInstanceStartOperator, \ + GceInstanceStopOperator, GceSetMachineTypeOperator + +# [START howto_operator_gce_args] +PROJECT_ID = models.Variable.get('PROJECT_ID', '') +LOCATION = models.Variable.get('LOCATION', '') +INSTANCE = models.Variable.get('INSTANCE', '') +SHORT_MACHINE_TYPE_NAME = models.Variable.get('SHORT_MACHINE_TYPE_NAME', '') +SET_MACHINE_TYPE_BODY = { + 'machineType': 'zones/{}/machineTypes/{}'.format(LOCATION, SHORT_MACHINE_TYPE_NAME) +} + +default_args = { + 'start_date': airflow.utils.dates.days_ago(1) +} +# [END howto_operator_gce_args] + +with models.DAG( + 'example_gcp_compute', + default_args=default_args, + schedule_interval=datetime.timedelta(days=1) +) as dag: + # [START howto_operator_gce_start] + gce_instance_start = GceInstanceStartOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE, + task_id='gcp_compute_start_task' + ) + # [END howto_operator_gce_start] + # Duplicate start for idempotence testing + gce_instance_start2 = GceInstanceStartOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE, + task_id='gcp_compute_start_task2' + ) + # [START howto_operator_gce_stop] + gce_instance_stop = GceInstanceStopOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE, + task_id='gcp_compute_stop_task' + ) + # [END howto_operator_gce_stop] + # Duplicate stop for idempotence testing + gce_instance_stop2 = GceInstanceStopOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE, + task_id='gcp_compute_stop_task2' + ) + # [START howto_operator_gce_set_machine_type] + gce_set_machine_type = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE, + body=SET_MACHINE_TYPE_BODY, + task_id='gcp_compute_set_machine_type' + ) + # [END 
howto_operator_gce_set_machine_type] + # Duplicate set machine type for idempotence testing + gce_set_machine_type2 = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE, + body=SET_MACHINE_TYPE_BODY, + task_id='gcp_compute_set_machine_type2' + ) + + gce_instance_start >> gce_instance_start2 >> gce_instance_stop >> \ + gce_instance_stop2 >> gce_set_machine_type >> gce_set_machine_type2 diff --git a/airflow/contrib/hooks/gcp_compute_hook.py b/airflow/contrib/hooks/gcp_compute_hook.py new file mode 100644 index 0000000000000..5fa088942b706 --- /dev/null +++ b/airflow/contrib/hooks/gcp_compute_hook.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import time +from googleapiclient.discovery import build + +from airflow import AirflowException +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook + +# Number of retries - used by googleapiclient method calls to perform retries +# For requests that are "retriable" +NUM_RETRIES = 5 + +# Time to sleep between active checks of the operation results +TIME_TO_SLEEP_IN_SECONDS = 1 + + +class GceOperationStatus: + PENDING = "PENDING" + RUNNING = "RUNNING" + DONE = "DONE" + + +# noinspection PyAbstractClass +class GceHook(GoogleCloudBaseHook): + """ + Hook for Google Compute Engine APIs. + """ + _conn = None + + def __init__(self, + api_version, + gcp_conn_id='google_cloud_default', + delegate_to=None): + super(GceHook, self).__init__(gcp_conn_id, delegate_to) + self.api_version = api_version + + def get_conn(self): + """ + Retrieves connection to Google Compute Engine. + + :return: Google Compute Engine services object + :rtype: dict + """ + if not self._conn: + http_authorized = self._authorize() + self._conn = build('compute', self.api_version, + http=http_authorized, cache_discovery=False) + return self._conn + + def start_instance(self, project_id, zone, resource_id): + """ + Starts an existing instance defined by project_id, zone and resource_id. + + :param project_id: Google Cloud Platform project where the Compute Engine + instance exists. + :type project_id: str + :param zone: Google Cloud Platform zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. 
+ :type resource_id: str + :return: True if the operation succeeded, raises an error otherwise + :rtype: bool + """ + response = self.get_conn().instances().start( + project=project_id, + zone=zone, + instance=resource_id + ).execute(num_retries=NUM_RETRIES) + operation_name = response["name"] + return self._wait_for_operation_to_complete(project_id, zone, operation_name) + + def stop_instance(self, project_id, zone, resource_id): + """ + Stops an instance defined by project_id, zone and resource_id. + + :param project_id: Google Cloud Platform project where the Compute Engine + instance exists. + :type project_id: str + :param zone: Google Cloud Platform zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :return: True if the operation succeeded, raises an error otherwise + :rtype: bool + """ + response = self.get_conn().instances().stop( + project=project_id, + zone=zone, + instance=resource_id + ).execute(num_retries=NUM_RETRIES) + operation_name = response["name"] + return self._wait_for_operation_to_complete(project_id, zone, operation_name) + + def set_machine_type(self, project_id, zone, resource_id, body): + """ + Sets machine type of an instance defined by project_id, zone and resource_id. + + :param project_id: Google Cloud Platform project where the Compute Engine + instance exists. + :type project_id: str + :param zone: Google Cloud Platform zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. 
+ :type resource_id: str + :param body: Body required by the Compute Engine setMachineType API, + as described in + https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType + :type body: dict + :return: True if the operation succeeded, raises an error otherwise + :rtype: bool + """ + response = self._execute_set_machine_type(project_id, zone, resource_id, body) + operation_name = response["name"] + return self._wait_for_operation_to_complete(project_id, zone, operation_name) + + def _execute_set_machine_type(self, project_id, zone, resource_id, body): + return self.get_conn().instances().setMachineType( + project=project_id, zone=zone, instance=resource_id, body=body)\ + .execute(num_retries=NUM_RETRIES) + + def _wait_for_operation_to_complete(self, project_id, zone, operation_name): + """ + Waits for the named operation to complete - checks status of the + asynchronous call. + + :param operation_name: name of the operation + :type operation_name: str + :return: True if the operation succeeded, raises an error otherwise + :rtype: bool + """ + service = self.get_conn() + while True: + operation_response = self._check_operation_status( + service, operation_name, project_id, zone) + if operation_response.get("status") == GceOperationStatus.DONE: + error = operation_response.get("error") + if error: + code = operation_response.get("httpErrorStatusCode") + msg = operation_response.get("httpErrorMessage") + # Extracting the errors list as string and trimming square braces + error_msg = str(error.get("errors"))[1:-1] + raise AirflowException("{} {}: ".format(code, msg) + error_msg) + # No meaningful info to return from the response in case of success + return True + time.sleep(TIME_TO_SLEEP_IN_SECONDS) + + def _check_operation_status(self, service, operation_name, project_id, zone): + return service.zoneOperations().get( + project=project_id, zone=zone, operation=operation_name).execute( + num_retries=NUM_RETRIES) diff --git 
a/airflow/contrib/operators/gcp_compute_operator.py b/airflow/contrib/operators/gcp_compute_operator.py new file mode 100644 index 0000000000000..a2fd54529429c --- /dev/null +++ b/airflow/contrib/operators/gcp_compute_operator.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow import AirflowException +from airflow.contrib.hooks.gcp_compute_hook import GceHook +from airflow.contrib.utils.gcp_field_validator import GcpBodyFieldValidator +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class GceBaseOperator(BaseOperator): + """ + Abstract base operator for Google Compute Engine operators to inherit from. 
+ """ + @apply_defaults + def __init__(self, + project_id, + zone, + resource_id, + gcp_conn_id='google_cloud_default', + api_version='v1', + *args, **kwargs): + self.project_id = project_id + self.zone = zone + self.full_location = 'projects/{}/zones/{}'.format(self.project_id, + self.zone) + self.resource_id = resource_id + self.gcp_conn_id = gcp_conn_id + self.api_version = api_version + self._validate_inputs() + self._hook = GceHook(gcp_conn_id=self.gcp_conn_id, api_version=self.api_version) + super(GceBaseOperator, self).__init__(*args, **kwargs) + + def _validate_inputs(self): + if not self.project_id: + raise AirflowException("The required parameter 'project_id' is missing") + if not self.zone: + raise AirflowException("The required parameter 'zone' is missing") + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing") + + def execute(self, context): + pass + + +class GceInstanceStartOperator(GceBaseOperator): + """ + Start an instance in Google Compute Engine. + + :param project_id: Google Cloud Platform project where the Compute Engine + instance exists. + :type project_id: str + :param zone: Google Cloud Platform zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). 
+ :type api_version: str + """ + template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id', 'api_version') + + @apply_defaults + def __init__(self, + project_id, + zone, + resource_id, + gcp_conn_id='google_cloud_default', + api_version='v1', + *args, **kwargs): + super(GceInstanceStartOperator, self).__init__( + project_id=project_id, zone=zone, resource_id=resource_id, + gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs) + + def execute(self, context): + return self._hook.start_instance(self.project_id, self.zone, self.resource_id) + + +class GceInstanceStopOperator(GceBaseOperator): + """ + Stop an instance in Google Compute Engine. + + :param project_id: Google Cloud Platform project where the Compute Engine + instance exists. + :type project_id: str + :param zone: Google Cloud Platform zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + """ + template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id', 'api_version') + + @apply_defaults + def __init__(self, + project_id, + zone, + resource_id, + gcp_conn_id='google_cloud_default', + api_version='v1', + *args, **kwargs): + super(GceInstanceStopOperator, self).__init__( + project_id=project_id, zone=zone, resource_id=resource_id, + gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs) + + def execute(self, context): + return self._hook.stop_instance(self.project_id, self.zone, self.resource_id) + + +SET_MACHINE_TYPE_VALIDATION_SPECIFICATION = [ + dict(name="machineType", regexp="^.+$"), +] + + +class GceSetMachineTypeOperator(GceBaseOperator): + """ + Changes the machine type for a stopped instance to the machine type specified in + the request. 
+ + :param project_id: Google Cloud Platform project where the Compute Engine + instance exists. + :type project_id: str + :param zone: Google Cloud Platform zone where the instance exists. + :type zone: str + :param resource_id: Name of the Compute Engine instance resource. + :type resource_id: str + :param body: Body required by the Compute Engine setMachineType API, as described in + https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType#request-body + :type body: dict + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. + :type gcp_conn_id: str + :param api_version: API version used (e.g. v1). + :type api_version: str + """ + template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id', 'api_version') + + @apply_defaults + def __init__(self, + project_id, + zone, + resource_id, + body, + gcp_conn_id='google_cloud_default', + api_version='v1', + validate_body=True, + *args, **kwargs): + self.body = body + self._field_validator = None + if validate_body: + self._field_validator = GcpBodyFieldValidator( + SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version) + super(GceSetMachineTypeOperator, self).__init__( + project_id=project_id, zone=zone, resource_id=resource_id, + gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs) + + def _validate_all_body_fields(self): + if self._field_validator: + self._field_validator.validate(self.body) + + def execute(self, context): + self._validate_all_body_fields() + return self._hook.set_machine_type(self.project_id, self.zone, + self.resource_id, self.body) diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py index 4455307c93259..8207b9d084f89 100644 --- a/airflow/contrib/operators/gcp_function_operator.py +++ b/airflow/contrib/operators/gcp_function_operator.py @@ -20,277 +20,23 @@ from googleapiclient.errors import HttpError -from airflow import AirflowException, 
LoggingMixin +from airflow import AirflowException +from airflow.contrib.utils.gcp_field_validator import GcpBodyFieldValidator, \ + GcpFieldValidationException from airflow.version import version from airflow.models import BaseOperator from airflow.contrib.hooks.gcp_function_hook import GcfHook from airflow.utils.decorators import apply_defaults -# TODO: This whole section should be extracted later to contrib/tools/field_validator.py - -COMPOSITE_FIELD_TYPES = ['union', 'dict'] - - -class FieldValidationException(AirflowException): - """ - Thrown when validation finds dictionary field not valid according to specification. - """ - - def __init__(self, message): - super(FieldValidationException, self).__init__(message) - - -class ValidationSpecificationException(AirflowException): - """ - Thrown when validation specification is wrong - (rather than dictionary being validated). - This should only happen during development as ideally - specification itself should not be invalid ;) . - """ - - def __init__(self, message): - super(ValidationSpecificationException, self).__init__(message) - - -# TODO: make better description, add some examples -# TODO: move to contrib/utils folder when we reuse it. -class BodyFieldValidator(LoggingMixin): - """ - Validates correctness of request body according to specification. - The specification can describe various type of - fields including custom validation, and union of fields. This validator is meant - to be reusable by various operators - in the near future, but for now it is left as part of the Google Cloud Function, - so documentation about the - validator is not yet complete. To see what kind of specification can be used, - please take a look at - gcp_function_operator.CLOUD_FUNCTION_VALIDATION which specifies validation - for GCF deploy operator. 
- - :param validation_specs: dictionary describing validation specification - :type validation_specs: [dict] - :param api_version: Version of the api used (for example v1) - :type api_version: str - - """ - def __init__(self, validation_specs, api_version): - # type: ([dict], str) -> None - super(BodyFieldValidator, self).__init__() - self._validation_specs = validation_specs - self._api_version = api_version - - @staticmethod - def _get_field_name_with_parent(field_name, parent): - if parent: - return parent + '.' + field_name - return field_name - - @staticmethod - def _sanity_checks(children_validation_specs, field_type, full_field_path, - regexp, custom_validation, value): - # type: (dict, str, str, str, function, object) -> None - if value is None and field_type != 'union': - raise FieldValidationException( - "The required body field '{}' is missing. Please add it.". - format(full_field_path)) - if regexp and field_type: - raise ValidationSpecificationException( - "The validation specification entry '{}' has both type and regexp. " - "The regexp is only allowed without type (i.e. assume type is 'str' " - "that can be validated with regexp)".format(full_field_path)) - if children_validation_specs and field_type not in COMPOSITE_FIELD_TYPES: - raise ValidationSpecificationException( - "Nested fields are specified in field '{}' of type '{}'. " - "Nested fields are only allowed for fields of those types: ('{}').". - format(full_field_path, field_type, COMPOSITE_FIELD_TYPES)) - if custom_validation and field_type: - raise ValidationSpecificationException( - "The validation specification field '{}' has both type and " - "custom_validation. Custom validation is only allowed without type.". 
- format(full_field_path)) - - @staticmethod - def _validate_regexp(full_field_path, regexp, value): - # type: (str, str, str) -> None - if not re.match(regexp, value): - # Note matching of only the beginning as we assume the regexps all-or-nothing - raise FieldValidationException( - "The body field '{}' of value '{}' does not match the field " - "specification regexp: '{}'.". - format(full_field_path, value, regexp)) - - def _validate_dict(self, children_validation_specs, full_field_path, value): - # type: (dict, str, dict) -> None - for child_validation_spec in children_validation_specs: - self._validate_field(validation_spec=child_validation_spec, - dictionary_to_validate=value, - parent=full_field_path) - for field_name in value.keys(): - if field_name not in [spec['name'] for spec in children_validation_specs]: - self.log.warning( - "The field '{}' is in the body, but is not specified in the " - "validation specification '{}'. " - "This might be because you are using newer API version and " - "new field names defined for that version. 
Then the warning " - "can be safely ignored, or you might want to upgrade the operator" - "to the version that supports the new API version.".format( - self._get_field_name_with_parent(field_name, full_field_path), - children_validation_specs)) - - def _validate_union(self, children_validation_specs, full_field_path, - dictionary_to_validate): - # type: (dict, str, dict) -> None - field_found = False - found_field_name = None - for child_validation_spec in children_validation_specs: - # Forcing optional so that we do not have to type optional = True - # in specification for all union fields - new_field_found = self._validate_field( - validation_spec=child_validation_spec, - dictionary_to_validate=dictionary_to_validate, - parent=full_field_path, - force_optional=True) - field_name = child_validation_spec['name'] - if new_field_found and field_found: - raise FieldValidationException( - "The mutually exclusive fields '{}' and '{}' belonging to the " - "union '{}' are both present. Please remove one". - format(field_name, found_field_name, full_field_path)) - if new_field_found: - field_found = True - found_field_name = field_name - if not field_found: - self.log.warning( - "There is no '{}' union defined in the body {}. " - "Validation expected one of '{}' but could not find any. It's possible " - "that you are using newer API version and there is another union variant " - "defined for that version. Then the warning can be safely ignored, " - "or you might want to upgrade the operator to the version that " - "supports the new API version.".format( - full_field_path, - dictionary_to_validate, - [field['name'] for field in children_validation_specs])) - - def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, - force_optional=False): - """ - Validates if field is OK. 
- :param validation_spec: specification of the field - :type validation_spec: dict - :param dictionary_to_validate: dictionary where the field should be present - :type dictionary_to_validate: dict - :param parent: full path of parent field - :type parent: str - :param force_optional: forces the field to be optional - (all union fields have force_optional set to True) - :type force_optional: bool - :return: True if the field is present - """ - field_name = validation_spec['name'] - field_type = validation_spec.get('type') - optional = validation_spec.get('optional') - regexp = validation_spec.get('regexp') - children_validation_specs = validation_spec.get('fields') - required_api_version = validation_spec.get('api_version') - custom_validation = validation_spec.get('custom_validation') - - full_field_path = self._get_field_name_with_parent(field_name=field_name, - parent=parent) - if required_api_version and required_api_version != self._api_version: - self.log.debug( - "Skipping validation of the field '{}' for API version '{}' " - "as it is only valid for API version '{}'". - format(field_name, self._api_version, required_api_version)) - return False - value = dictionary_to_validate.get(field_name) - - if (optional or force_optional) and value is None: - self.log.debug("The optional field '{}' is missing. That's perfectly OK.". - format(full_field_path)) - return False - - # Certainly down from here the field is present (value is not None) - # so we should only return True from now on - - self._sanity_checks(children_validation_specs=children_validation_specs, - field_type=field_type, - full_field_path=full_field_path, - regexp=regexp, - custom_validation=custom_validation, - value=value) - - if regexp: - self._validate_regexp(full_field_path, regexp, value) - elif field_type == 'dict': - if not isinstance(value, dict): - raise FieldValidationException( - "The field '{}' should be dictionary type according to " - "specification '{}' but it is '{}'". 
- format(full_field_path, validation_spec, value)) - if children_validation_specs is None: - self.log.debug( - "The dict field '{}' has no nested fields defined in the " - "specification '{}'. That's perfectly ok - it's content will " - "not be validated." - .format(full_field_path, validation_spec)) - else: - self._validate_dict(children_validation_specs, full_field_path, value) - elif field_type == 'union': - if not children_validation_specs: - raise ValidationSpecificationException( - "The union field '{}' has no nested fields " - "defined in specification '{}'. Unions should have at least one " - "nested field defined.".format(full_field_path, validation_spec)) - self._validate_union(children_validation_specs, full_field_path, - dictionary_to_validate) - elif custom_validation: - try: - custom_validation(value) - except Exception as e: - raise FieldValidationException( - "Error while validating custom field '{}' specified by '{}': '{}'". - format(full_field_path, validation_spec, e)) - elif field_type is None: - self.log.debug("The type of field '{}' is not specified in '{}'. " - "Not validating its content.". - format(full_field_path, validation_spec)) - else: - raise ValidationSpecificationException( - "The field '{}' is of type '{}' in specification '{}'." - "This type is unknown to validation!".format( - full_field_path, field_type, validation_spec)) - return True - - def validate(self, body_to_validate): - """ - Validates if the body (dictionary) follows specification that the validator was - instantiated with. Raises ValidationSpecificationException or - ValidationFieldException in case of problems with specification or the - body not conforming to the specification respectively. 
- :param body_to_validate: body that must follow the specification - :type body_to_validate: dict - :return: None - """ - try: - for validation_spec in self._validation_specs: - self._validate_field(validation_spec=validation_spec, - dictionary_to_validate=body_to_validate) - except FieldValidationException as e: - raise FieldValidationException( - "There was an error when validating: field '{}': '{}'". - format(body_to_validate, e)) - -# TODO End of field validator to be extracted - def _validate_available_memory_in_mb(value): if int(value) <= 0: - raise FieldValidationException("The available memory has to be greater than 0") + raise GcpFieldValidationException("The available memory has to be greater than 0") def _validate_max_instances(value): if int(value) <= 0: - raise FieldValidationException( + raise GcpFieldValidationException( "The max instances parameter has to be greater than 0") @@ -378,9 +124,10 @@ def __init__(self, self.api_version = api_version self.zip_path = zip_path self.zip_path_preprocessor = ZipPathPreprocessor(body, zip_path) - self.validate_body = validate_body - self._field_validator = BodyFieldValidator(CLOUD_FUNCTION_VALIDATION, - api_version=api_version) + self._field_validator = None + if validate_body: + self._field_validator = GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION, + api_version=api_version) self._hook = GcfHook(gcp_conn_id=self.gcp_conn_id, api_version=self.api_version) self._validate_inputs() super(GcfFunctionDeployOperator, self).__init__(*args, **kwargs) @@ -395,7 +142,8 @@ def _validate_inputs(self): self.zip_path_preprocessor.preprocess_body() def _validate_all_body_fields(self): - self._field_validator.validate(self.body) + if self._field_validator: + self._field_validator.validate(self.body) def _create_new_function(self): self._hook.create_new_function(self.full_location, self.body) @@ -406,8 +154,8 @@ def _update_function(self): def _check_if_function_exists(self): name = self.body.get('name') if not name: - raise 
FieldValidationException("The 'name' field should be present in " - "body: '{}'.".format(self.body)) + raise GcpFieldValidationException("The 'name' field should be present in " + "body: '{}'.".format(self.body)) try: self._hook.get_function(name) except HttpError as e: @@ -430,8 +178,7 @@ def _set_airflow_version_label(self): def execute(self, context): if self.zip_path_preprocessor.should_upload_function(): self.body[SOURCE_UPLOAD_URL] = self._upload_source_code() - if self.validate_body: - self._validate_all_body_fields() + self._validate_all_body_fields() self._set_airflow_version_label() if not self._check_if_function_exists(): self._create_new_function() diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py new file mode 100644 index 0000000000000..20f72d94b813a --- /dev/null +++ b/airflow/contrib/utils/gcp_field_validator.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Validator for body fields sent via GCP API. + +The validator performs validation of the body (being dictionary of fields) that +is sent in the API request to Google Cloud (via googleclient API usually). 
+ +Context +------- +The specification mostly focuses on helping Airflow DAG developers in the development +phase. You can build your own GCP operator (such as GcfDeployOperator for example) which +can have built-in validation specification for the particular API. It's super helpful +when developer plays with different fields and their values at the initial phase of +DAG development. Most of the Google Cloud APIs perform their own validation on the +server side, but most of the requests are asynchronous and you need to wait for result +of the operation. This takes precious time and slows +down iteration over the API. GcpBodyFieldValidator is meant to be used on the client side +and it should therefore provide instant feedback to the developer on misspelled or +wrong type of parameters. + +The validation should be performed in "execute()" method call in order to allow +template parameters to be expanded before validation is performed. + +Types of fields +--------------- + +Specification is an array of dictionaries - each dictionary describes field, its type, +validation, optionality, api_version supported and nested fields (for unions and dicts). + +Typically (for clarity and in order to aid syntax highlighting) the array of +dicts should be defined as series of dict() executions. Fragment of example +specification might look as follows: + +``` +SPECIFICATION = [ + dict(name="an_union", type="union", optional=True, fields=[ + dict(name="variant_1", type="dict"), + dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), + ]), + dict(name="a_dict", type="dict", fields=[ + dict(name="field_1", type="dict"), + dict(name="field_2", regexp=r'^.+$'), + ]), + ... +] +``` + +Each field should have key = "name" indicating field name. The field can be of one of the +following types: + +* Dict fields: (key = "type", value="dict"): + Field of this type should contain nested fields in form of an array of dicts. 
+ Each of the fields in the array is then expected (unless marked as optional) + and validated recursively. If an extra field is present in the dictionary, a warning is + printed in the log file (but the validation succeeds - see the Forward-compatibility notes) +* Union fields (key = "type", value="union"): field of this type should contain nested + fields in form of an array of dicts. One of the fields (and only one) should be + present (unless the union is marked as optional). If more than one union field is + present, GcpFieldValidationException is raised. If none of the union fields is + present - a warning is printed in the log (see below Forward-compatibility notes). +* Regexp-validated fields: (key = "regexp") - fields of this type are assumed to be + strings and they are validated with the regexp specified. Remember that the regexps + should ideally contain ^ at the beginning and $ at the end to make sure that + the whole field content is validated. Typically such regexp + validations should be used carefully and sparingly (see Forward-compatibility + notes below). Most regexp validations should be at most r'^.+$'. +* Custom-validated fields: (key = "custom_validation") - fields of this type are validated + using the method specified via custom_validation field. Any exception thrown in the custom + validation will be turned into GcpFieldValidationException and will cause validation to + fail. Such custom validations might be used to check numeric fields (including + ranges of values), booleans or any other types of fields. +* API version: (key="api_version") if API version is specified, then the field will only + be validated when api_version used at field validator initialization matches exactly + the version specified. If you want to declare fields that are available in several + versions of the APIs, you should specify the field as many times as many API versions + should be supported (each time with different API version). 
+* If none of the keys ("type", "regexp", "custom_validation") is specified - the field is not validated + +You can see some of the field examples in EXAMPLE_VALIDATION_SPECIFICATION. + + +Forward-compatibility notes +--------------------------- +Certain decisions are crucial to allow the client APIs to work also with future API +versions. Since the body attached is passed to the API call, it is entirely +possible to pass-through any new fields in the body (for future API versions) - +albeit without validation on the client side - they can and will still be validated +on the server side usually. + +Here are the guidelines that you should follow to make validation forward-compatible: + +* most of the fields are not validated for their content. It's possible to use regexp + in some specific cases that are guaranteed not to change in the future, but for most + fields regexp validation should be r'^.+$' indicating check for non-emptiness +* api_version is not validated - user can pass any future version of the api here. The API + version is only used to filter parameters that are marked as present in this api version. +* Any new (not present in the specification) fields in the body are allowed (not verified). + For dictionaries, new fields can be added to dictionaries by future calls. However if an + unknown field in dictionary is added, a warning is logged by the client (but validation + remains successful). This is a very nice feature to protect against typos in names. +* For unions, newly added union variants can be added by future calls and they will + pass validation, however the content or presence of those fields will not be validated. + This means that it's possible to send a new non-validated union field together with an + old validated field and this problem will not be detected by the client. In such a case + a warning will be printed. 
+* When you add validator to an operator, you should also add ``validate_body`` parameter + (default = True) to __init__ of such operators - when it is set to False, + no validation should be performed. This is a safeguard for totally unpredicted and + backwards-incompatible changes that might sometimes occur in the APIs. + +""" + +import re + +from airflow import LoggingMixin, AirflowException + +COMPOSITE_FIELD_TYPES = ['union', 'dict'] + + +class GcpFieldValidationException(AirflowException): + """Thrown when validation finds dictionary field not valid according to specification. + """ + + def __init__(self, message): + super(GcpFieldValidationException, self).__init__(message) + + +class GcpValidationSpecificationException(AirflowException): + """Thrown when validation specification is wrong. + + This should only happen during development as ideally + specification itself should not be invalid ;) . + """ + + def __init__(self, message): + super(GcpValidationSpecificationException, self).__init__(message) + + +def _int_greater_than_zero(value): + if int(value) <= 0: + raise GcpFieldValidationException("The available memory has to be greater than 0") + + +EXAMPLE_VALIDATION_SPECIFICATION = [ + dict(name="name", regexp="^.+$"), + dict(name="description", regexp="^.+$", optional=True), + dict(name="availableMemoryMb", custom_validation=_int_greater_than_zero, + optional=True), + dict(name="labels", optional=True, type="dict"), + dict(name="an_union", type="union", fields=[ + dict(name="variant_1", regexp=r'^.+$'), + dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), + dict(name="variant_3", type="dict", fields=[ + dict(name="url", regexp=r'^.+$') + ]), + dict(name="variant_4") + ]), +] + + +class GcpBodyFieldValidator(LoggingMixin): + """Validates correctness of request body according to specification. + + The specification can describe various type of + fields including custom validation, and union of fields. 
This validator is + to be reusable by various operators. See the EXAMPLE_VALIDATION_SPECIFICATION + for some examples and explanations of how to create specification. + + :param validation_specs: dictionary describing validation specification + :type validation_specs: [dict] + :param api_version: Version of the api used (for example v1) + :type api_version: str + + """ + def __init__(self, validation_specs, api_version): + # type: ([dict], str) -> None + super(GcpBodyFieldValidator, self).__init__() + self._validation_specs = validation_specs + self._api_version = api_version + + @staticmethod + def _get_field_name_with_parent(field_name, parent): + if parent: + return parent + '.' + field_name + return field_name + + @staticmethod + def _sanity_checks(children_validation_specs, field_type, full_field_path, + regexp, custom_validation, value): + # type: (dict, str, str, str, function, object) -> None + if value is None and field_type != 'union': + raise GcpFieldValidationException( + "The required body field '{}' is missing. Please add it.". + format(full_field_path)) + if regexp and field_type: + raise GcpValidationSpecificationException( + "The validation specification entry '{}' has both type and regexp. " + "The regexp is only allowed without type (i.e. assume type is 'str' " + "that can be validated with regexp)".format(full_field_path)) + if children_validation_specs and field_type not in COMPOSITE_FIELD_TYPES: + raise GcpValidationSpecificationException( + "Nested fields are specified in field '{}' of type '{}'. " + "Nested fields are only allowed for fields of those types: ('{}').". + format(full_field_path, field_type, COMPOSITE_FIELD_TYPES)) + if custom_validation and field_type: + raise GcpValidationSpecificationException( + "The validation specification field '{}' has both type and " + "custom_validation. Custom validation is only allowed without type.". 
+ format(full_field_path)) + + @staticmethod + def _validate_regexp(full_field_path, regexp, value): + # type: (str, str, str) -> None + if not re.match(regexp, value): + # Note matching of only the beginning as we assume the regexps all-or-nothing + raise GcpFieldValidationException( + "The body field '{}' of value '{}' does not match the field " + "specification regexp: '{}'.". + format(full_field_path, value, regexp)) + + def _validate_dict(self, children_validation_specs, full_field_path, value): + # type: (dict, str, dict) -> None + for child_validation_spec in children_validation_specs: + self._validate_field(validation_spec=child_validation_spec, + dictionary_to_validate=value, + parent=full_field_path) + all_dict_keys = [spec['name'] for spec in children_validation_specs] + for field_name in value.keys(): + if field_name not in all_dict_keys: + self.log.warning( + "The field '{}' is in the body, but is not specified in the " + "validation specification '{}'. " + "This might be because you are using newer API version and " + "new field names defined for that version. 
Then the warning " + "can be safely ignored, or you might want to upgrade the operator" + "to the version that supports the new API version.".format( + self._get_field_name_with_parent(field_name, full_field_path), + children_validation_specs)) + + def _validate_union(self, children_validation_specs, full_field_path, + dictionary_to_validate): + # type: (dict, str, dict) -> None + field_found = False + found_field_name = None + for child_validation_spec in children_validation_specs: + # Forcing optional so that we do not have to type optional = True + # in specification for all union fields + new_field_found = self._validate_field( + validation_spec=child_validation_spec, + dictionary_to_validate=dictionary_to_validate, + parent=full_field_path, + force_optional=True) + field_name = child_validation_spec['name'] + if new_field_found and field_found: + raise GcpFieldValidationException( + "The mutually exclusive fields '{}' and '{}' belonging to the " + "union '{}' are both present. Please remove one". + format(field_name, found_field_name, full_field_path)) + if new_field_found: + field_found = True + found_field_name = field_name + if not field_found: + self.log.warning( + "There is no '{}' union defined in the body {}. " + "Validation expected one of '{}' but could not find any. It's possible " + "that you are using newer API version and there is another union variant " + "defined for that version. Then the warning can be safely ignored, " + "or you might want to upgrade the operator to the version that " + "supports the new API version.".format( + full_field_path, + dictionary_to_validate, + [field['name'] for field in children_validation_specs])) + + def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, + force_optional=False): + """ + Validates if field is OK. 
+ :param validation_spec: specification of the field + :type validation_spec: dict + :param dictionary_to_validate: dictionary where the field should be present + :type dictionary_to_validate: dict + :param parent: full path of parent field + :type parent: str + :param force_optional: forces the field to be optional + (all union fields have force_optional set to True) + :type force_optional: bool + :return: True if the field is present + """ + field_name = validation_spec['name'] + field_type = validation_spec.get('type') + optional = validation_spec.get('optional') + regexp = validation_spec.get('regexp') + children_validation_specs = validation_spec.get('fields') + required_api_version = validation_spec.get('api_version') + custom_validation = validation_spec.get('custom_validation') + + full_field_path = self._get_field_name_with_parent(field_name=field_name, + parent=parent) + if required_api_version and required_api_version != self._api_version: + self.log.debug( + "Skipping validation of the field '{}' for API version '{}' " + "as it is only valid for API version '{}'". + format(field_name, self._api_version, required_api_version)) + return False + value = dictionary_to_validate.get(field_name) + + if (optional or force_optional) and value is None: + self.log.debug("The optional field '{}' is missing. That's perfectly OK.". + format(full_field_path)) + return False + + # Certainly down from here the field is present (value is not None) + # so we should only return True from now on + + self._sanity_checks(children_validation_specs=children_validation_specs, + field_type=field_type, + full_field_path=full_field_path, + regexp=regexp, + custom_validation=custom_validation, + value=value) + + if regexp: + self._validate_regexp(full_field_path, regexp, value) + elif field_type == 'dict': + if not isinstance(value, dict): + raise GcpFieldValidationException( + "The field '{}' should be dictionary type according to " + "specification '{}' but it is '{}'". 
+ format(full_field_path, validation_spec, value)) + if children_validation_specs is None: + self.log.debug( + "The dict field '{}' has no nested fields defined in the " + "specification '{}'. That's perfectly ok - it's content will " + "not be validated." + .format(full_field_path, validation_spec)) + else: + self._validate_dict(children_validation_specs, full_field_path, value) + elif field_type == 'union': + if not children_validation_specs: + raise GcpValidationSpecificationException( + "The union field '{}' has no nested fields " + "defined in specification '{}'. Unions should have at least one " + "nested field defined.".format(full_field_path, validation_spec)) + self._validate_union(children_validation_specs, full_field_path, + dictionary_to_validate) + elif custom_validation: + try: + custom_validation(value) + except Exception as e: + raise GcpFieldValidationException( + "Error while validating custom field '{}' specified by '{}': '{}'". + format(full_field_path, validation_spec, e)) + elif field_type is None: + self.log.debug("The type of field '{}' is not specified in '{}'. " + "Not validating its content.". + format(full_field_path, validation_spec)) + else: + raise GcpValidationSpecificationException( + "The field '{}' is of type '{}' in specification '{}'." + "This type is unknown to validation!".format( + full_field_path, field_type, validation_spec)) + return True + + def validate(self, body_to_validate): + """ + Validates if the body (dictionary) follows specification that the validator was + instantiated with. Raises ValidationSpecificationException or + ValidationFieldException in case of problems with specification or the + body not conforming to the specification respectively. 
+ :param body_to_validate: body that must follow the specification + :type body_to_validate: dict + :return: None + """ + try: + for validation_spec in self._validation_specs: + self._validate_field(validation_spec=validation_spec, + dictionary_to_validate=body_to_validate) + except GcpFieldValidationException as e: + raise GcpFieldValidationException( + "There was an error when validating: body '{}': '{}'". + format(body_to_validate, e)) + all_field_names = [spec['name'] for spec in self._validation_specs + if spec.get('type') != 'union' and + spec.get('api_version') != self._api_version] + all_union_fields = [spec for spec in self._validation_specs + if spec.get('type') == 'union'] + for union_field in all_union_fields: + all_field_names.extend( + [nested_union_spec['name'] for nested_union_spec in union_field['fields'] + if nested_union_spec.get('type') != 'union' and + nested_union_spec.get('api_version') != self._api_version]) + for field_name in body_to_validate.keys(): + if field_name not in all_field_names: + self.log.warning( + "The field '{}' is in the body, but is not specified in the " + "validation specification '{}'. " + "This might be because you are using newer API version and " + "new field names defined for that version. Then the warning " + "can be safely ignored, or you might want to upgrade the operator" + "to the version that supports the new API version.".format( + field_name, self._validation_specs)) diff --git a/docs/howto/operator.rst b/docs/howto/operator.rst index 0d973a391cc7b..549d67757049f 100644 --- a/docs/howto/operator.rst +++ b/docs/howto/operator.rst @@ -102,6 +102,62 @@ to execute a BigQuery load job. :start-after: [START howto_operator_gcs_to_bq] :end-before: [END howto_operator_gcs_to_bq] +GceInstanceStartOperator +^^^^^^^^^^^^^^^^^^^^^^^^ + +Allows to start an existing Google Compute Engine instance. + +In this example parameter values are extracted from Airflow variables. 
+Moreover, the ``default_args`` dict is used to pass common arguments to all operators in a single DAG. + +.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py + :language: python + :start-after: [START howto_operator_gce_args] + :end-before: [END howto_operator_gce_args] + + +Define the :class:`~airflow.contrib.operators.gcp_compute_operator +.GceInstanceStartOperator` by passing the required arguments to the constructor. + +.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_gce_start] + :end-before: [END howto_operator_gce_start] + +GceInstanceStopOperator +^^^^^^^^^^^^^^^^^^^^^^^ + +Allows to stop an existing Google Compute Engine instance. + +For parameter definition take a look at :class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator` above. + +Define the :class:`~airflow.contrib.operators.gcp_compute_operator +.GceInstanceStopOperator` by passing the required arguments to the constructor. + +.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_gce_stop] + :end-before: [END howto_operator_gce_stop] + +GceSetMachineTypeOperator +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Allows to change the machine type for a stopped instance to the specified machine type. + +For parameter definition take a look at :class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator` above. + +Define the :class:`~airflow.contrib.operators.gcp_compute_operator +.GceSetMachineTypeOperator` by passing the required arguments to the constructor. + +.. 
literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_gce_set_machine_type] + :end-before: [END howto_operator_gce_set_machine_type] + + GcfFunctionDeleteOperator ^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/integration.rst b/docs/integration.rst index 6ef7bd8398705..3a5a3c3e0507d 100644 --- a/docs/integration.rst +++ b/docs/integration.rst @@ -457,6 +457,37 @@ BigQueryHook .. autoclass:: airflow.contrib.hooks.bigquery_hook.BigQueryHook :members: +Compute Engine +'''''''''''''' + +Compute Engine Operators +"""""""""""""""""""""""" + +- :ref:`GceInstanceStartOperator` : start an existing Google Compute Engine instance. +- :ref:`GceInstanceStopOperator` : stop an existing Google Compute Engine instance. +- :ref:`GceSetMachineTypeOperator` : change the machine type for a stopped instance. + +.. _GceInstanceStartOperator: + +GceInstanceStartOperator +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator + +.. _GceInstanceStopOperator: + +GceInstanceStopOperator +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: airflow.contrib.operators.gcp_compute_operator.GceInstanceStopOperator + +.. _GceSetMachineTypeOperator: + +GceSetMachineTypeOperator +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: airflow.contrib.operators.gcp_compute_operator.GceSetMachineTypeOperator + Cloud Functions ''''''''''''''' diff --git a/tests/contrib/operators/test_gcp_compute_operator.py b/tests/contrib/operators/test_gcp_compute_operator.py new file mode 100644 index 0000000000000..449c4e015fdda --- /dev/null +++ b/tests/contrib/operators/test_gcp_compute_operator.py @@ -0,0 +1,377 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import ast +import unittest + +from airflow import AirflowException, configuration +from airflow.contrib.operators.gcp_compute_operator import GceInstanceStartOperator, \ + GceInstanceStopOperator, GceSetMachineTypeOperator +from airflow.models import TaskInstance, DAG +from airflow.utils import timezone + +try: + # noinspection PyProtectedMember + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +PROJECT_ID = 'project-id' +LOCATION = 'zone' +RESOURCE_ID = 'resource-id' +SHORT_MACHINE_TYPE_NAME = 'n1-machine-type' +SET_MACHINE_TYPE_BODY = { + 'machineType': 'zones/{}/machineTypes/{}'.format(LOCATION, SHORT_MACHINE_TYPE_NAME) +} + +DEFAULT_DATE = timezone.datetime(2017, 1, 1) + + +class GceInstanceStartTest(unittest.TestCase): + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_instance_start(self, mock_hook): + mock_hook.return_value.start_instance.return_value = True + op = GceInstanceStartOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=RESOURCE_ID, + task_id='id' + ) + result = op.execute(None) + mock_hook.assert_called_once_with(api_version='v1', + gcp_conn_id='google_cloud_default') + mock_hook.return_value.start_instance.assert_called_once_with( + PROJECT_ID, LOCATION, RESOURCE_ID + ) + self.assertTrue(result) + + # Setting all of the operator's input parameters as templated dag_ids + # 
(could be anything else) just to test if the templating works for all fields + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_instance_start_with_templates(self, mock_hook): + dag_id = 'test_dag_id' + configuration.load_test_config() + args = { + 'start_date': DEFAULT_DATE + } + self.dag = DAG(dag_id, default_args=args) + op = GceInstanceStartOperator( + project_id='{{ dag.dag_id }}', + zone='{{ dag.dag_id }}', + resource_id='{{ dag.dag_id }}', + gcp_conn_id='{{ dag.dag_id }}', + api_version='{{ dag.dag_id }}', + task_id='id', + dag=self.dag + ) + ti = TaskInstance(op, DEFAULT_DATE) + ti.render_templates() + self.assertEqual(dag_id, getattr(op, 'project_id')) + self.assertEqual(dag_id, getattr(op, 'zone')) + self.assertEqual(dag_id, getattr(op, 'resource_id')) + self.assertEqual(dag_id, getattr(op, 'gcp_conn_id')) + self.assertEqual(dag_id, getattr(op, 'api_version')) + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_start_should_throw_ex_when_missing_project_id(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceInstanceStartOperator( + project_id="", + zone=LOCATION, + resource_id=RESOURCE_ID, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'project_id' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_start_should_throw_ex_when_missing_zone(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceInstanceStartOperator( + project_id=PROJECT_ID, + zone="", + resource_id=RESOURCE_ID, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'zone' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_start_should_throw_ex_when_missing_resource_id(self, mock_hook): + with 
self.assertRaises(AirflowException) as cm: + op = GceInstanceStartOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id="", + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'resource_id' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_instance_stop(self, mock_hook): + mock_hook.return_value.stop_instance.return_value = True + op = GceInstanceStopOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=RESOURCE_ID, + task_id='id' + ) + result = op.execute(None) + mock_hook.assert_called_once_with(api_version='v1', + gcp_conn_id='google_cloud_default') + mock_hook.return_value.stop_instance.assert_called_once_with( + PROJECT_ID, LOCATION, RESOURCE_ID + ) + self.assertTrue(result) + + # Setting all of the operator's input parameters as templated dag_ids + # (could be anything else) just to test if the templating works for all fields + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_instance_stop_with_templates(self, mock_hook): + dag_id = 'test_dag_id' + configuration.load_test_config() + args = { + 'start_date': DEFAULT_DATE + } + self.dag = DAG(dag_id, default_args=args) + op = GceInstanceStopOperator( + project_id='{{ dag.dag_id }}', + zone='{{ dag.dag_id }}', + resource_id='{{ dag.dag_id }}', + gcp_conn_id='{{ dag.dag_id }}', + api_version='{{ dag.dag_id }}', + task_id='id', + dag=self.dag + ) + ti = TaskInstance(op, DEFAULT_DATE) + ti.render_templates() + self.assertEqual(dag_id, getattr(op, 'project_id')) + self.assertEqual(dag_id, getattr(op, 'zone')) + self.assertEqual(dag_id, getattr(op, 'resource_id')) + self.assertEqual(dag_id, getattr(op, 'gcp_conn_id')) + self.assertEqual(dag_id, getattr(op, 'api_version')) + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook): + with 
self.assertRaises(AirflowException) as cm: + op = GceInstanceStopOperator( + project_id="", + zone=LOCATION, + resource_id=RESOURCE_ID, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'project_id' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_stop_should_throw_ex_when_missing_zone(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceInstanceStopOperator( + project_id=PROJECT_ID, + zone="", + resource_id=RESOURCE_ID, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'zone' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_stop_should_throw_ex_when_missing_resource_id(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceInstanceStopOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id="", + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'resource_id' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_set_machine_type(self, mock_hook): + mock_hook.return_value.set_machine_type.return_value = True + op = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=RESOURCE_ID, + body=SET_MACHINE_TYPE_BODY, + task_id='id' + ) + result = op.execute(None) + mock_hook.assert_called_once_with(api_version='v1', + gcp_conn_id='google_cloud_default') + mock_hook.return_value.set_machine_type.assert_called_once_with( + PROJECT_ID, LOCATION, RESOURCE_ID, SET_MACHINE_TYPE_BODY + ) + self.assertTrue(result) + + # Setting all of the operator's input parameters as templated dag_ids + # (could be anything else) just to test if the templating works for all fields + 
@mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_set_machine_type_with_templates(self, mock_hook): + dag_id = 'test_dag_id' + configuration.load_test_config() + args = { + 'start_date': DEFAULT_DATE + } + self.dag = DAG(dag_id, default_args=args) + op = GceSetMachineTypeOperator( + project_id='{{ dag.dag_id }}', + zone='{{ dag.dag_id }}', + resource_id='{{ dag.dag_id }}', + body={}, + gcp_conn_id='{{ dag.dag_id }}', + api_version='{{ dag.dag_id }}', + task_id='id', + dag=self.dag + ) + ti = TaskInstance(op, DEFAULT_DATE) + ti.render_templates() + self.assertEqual(dag_id, getattr(op, 'project_id')) + self.assertEqual(dag_id, getattr(op, 'zone')) + self.assertEqual(dag_id, getattr(op, 'resource_id')) + self.assertEqual(dag_id, getattr(op, 'gcp_conn_id')) + self.assertEqual(dag_id, getattr(op, 'api_version')) + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceSetMachineTypeOperator( + project_id="", + zone=LOCATION, + resource_id=RESOURCE_ID, + body=SET_MACHINE_TYPE_BODY, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'project_id' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_set_machine_type_should_throw_ex_when_missing_zone(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone="", + resource_id=RESOURCE_ID, + body=SET_MACHINE_TYPE_BODY, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'zone' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_set_machine_type_should_throw_ex_when_missing_resource_id(self, 
mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id="", + body=SET_MACHINE_TYPE_BODY, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn("The required parameter 'resource_id' is missing", str(err)) + mock_hook.assert_not_called() + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook') + def test_set_machine_type_should_throw_ex_when_missing_machine_type(self, mock_hook): + with self.assertRaises(AirflowException) as cm: + op = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=RESOURCE_ID, + body={}, + task_id='id' + ) + op.execute(None) + err = cm.exception + self.assertIn( + "The required body field 'machineType' is missing. Please add it.", str(err)) + mock_hook.assert_called_once_with(api_version='v1', + gcp_conn_id='google_cloud_default') + + MOCK_OP_RESPONSE = "{'kind': 'compute#operation', 'id': '8529919847974922736', " \ + "'name': " \ + "'operation-1538578207537-577542784f769-7999ab71-94f9ec1d', " \ + "'zone': 'https://www.googleapis.com/compute/v1/projects/polidea" \ + "-airflow/zones/europe-west3-b', 'operationType': " \ + "'setMachineType', 'targetLink': " \ + "'https://www.googleapis.com/compute/v1/projects/polidea-airflow" \ + "/zones/europe-west3-b/instances/pa-1', 'targetId': " \ + "'2480086944131075860', 'status': 'DONE', 'user': " \ + "'uberdarek@polidea-airflow.iam.gserviceaccount.com', " \ + "'progress': 100, 'insertTime': '2018-10-03T07:50:07.951-07:00', "\ + "'startTime': '2018-10-03T07:50:08.324-07:00', 'endTime': " \ + "'2018-10-03T07:50:08.484-07:00', 'error': {'errors': [{'code': " \ + "'UNSUPPORTED_OPERATION', 'message': \"Machine type with name " \ + "'machine-type-1' does not exist in zone 'europe-west3-b'.\"}]}, "\ + "'httpErrorStatusCode': 400, 'httpErrorMessage': 'BAD REQUEST', " \ + "'selfLink': " \ + 
"'https://www.googleapis.com/compute/v1/projects/polidea-airflow" \ + "/zones/europe-west3-b/operations/operation-1538578207537" \ + "-577542784f769-7999ab71-94f9ec1d'} " + + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook' + '._check_operation_status') + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook' + '._execute_set_machine_type') + @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook.get_conn') + def test_set_machine_type_should_handle_and_trim_gce_error( + self, get_conn, _execute_set_machine_type, _check_operation_status): + get_conn.return_value = {} + _execute_set_machine_type.return_value = {"name": "test-operation"} + _check_operation_status.return_value = ast.literal_eval(self.MOCK_OP_RESPONSE) + with self.assertRaises(AirflowException) as cm: + op = GceSetMachineTypeOperator( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=RESOURCE_ID, + body=SET_MACHINE_TYPE_BODY, + task_id='id' + ) + op.execute(None) + err = cm.exception + _check_operation_status.assert_called_once_with( + {}, "test-operation", PROJECT_ID, LOCATION) + _execute_set_machine_type.assert_called_once_with( + PROJECT_ID, LOCATION, RESOURCE_ID, SET_MACHINE_TYPE_BODY) + # Checking the full message was sometimes failing due to different order + # of keys in the serialized JSON + self.assertIn("400 BAD REQUEST: {", str(err)) # checking the square bracket trim + self.assertIn("UNSUPPORTED_OPERATION", str(err)) diff --git a/tests/contrib/operators/test_gcp_function_operator.py b/tests/contrib/operators/test_gcp_function_operator.py index d7585ae66fdef..4192560dd984c 100644 --- a/tests/contrib/operators/test_gcp_function_operator.py +++ b/tests/contrib/operators/test_gcp_function_operator.py @@ -519,6 +519,23 @@ def test_valid_trigger_union_field(self, trigger, mock_hook): ) mock_hook.reset_mock() + @mock.patch('airflow.contrib.operators.gcp_function_operator.GcfHook') + def test_extra_parameter(self, mock_hook): + 
mock_hook.return_value.list_functions.return_value = [] + mock_hook.return_value.create_new_function.return_value = True + body = deepcopy(VALID_BODY) + body['extra_parameter'] = 'extra' + op = GcfFunctionDeployOperator( + project_id="test_project_id", + location="test_region", + body=body, + task_id="id" + ) + op.execute(None) + mock_hook.assert_called_once_with(api_version='v1', + gcp_conn_id='google_cloud_default') + mock_hook.reset_mock() + class GcfFunctionDeleteTest(unittest.TestCase): _FUNCTION_NAME = 'projects/project_name/locations/project_location/functions' \