From 4bddc1de03dd764606c68d9680d1cad511d83f24 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 8 Apr 2024 12:00:19 -0700 Subject: [PATCH] Amazon Bedrock - Clean up hook unit tests --- .../amazon/aws/hooks/test_bedrock.py | 47 ++++--------------- 1 file changed, 10 insertions(+), 37 deletions(-) diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index 16752477d5631..43b467d549a71 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -16,46 +16,19 @@ # under the License. from __future__ import annotations -from unittest import mock - import pytest -from botocore.exceptions import ClientError from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook -JOB_NAME = "testJobName" -EXPECTED_STATUS = "InProgress" - - -@pytest.fixture -def mock_conn(): - with mock.patch.object(BedrockHook, "conn") as _conn: - _conn.get_model_customization_job.return_value = {"jobName": JOB_NAME, "status": EXPECTED_STATUS} - yield _conn - - -class TestBedrockHook: - VALIDATION_EXCEPTION_ERROR = ClientError( - error_response={"Error": {"Code": "ValidationException", "Message": ""}}, - operation_name="GetModelCustomizationJob", - ) - UNEXPECTED_EXCEPTION = ClientError( - error_response={"Error": {"Code": "ExpiredTokenException", "Message": ""}}, - operation_name="GetModelCustomizationJob", +class TestBedrockHooks: + @pytest.mark.parametrize( + "test_hook, service_name", + [ + pytest.param(BedrockHook(), "bedrock", id="bedrock"), + pytest.param(BedrockRuntimeHook(), "bedrock-runtime", id="bedrock-runtime"), + ], ) - - def setup_method(self): - self.hook = BedrockHook() - - def test_conn_returns_a_boto3_connection(self): - assert self.hook.conn is not None - assert self.hook.conn.meta.service_model.service_name == "bedrock" - - -class TestBedrockRuntimeHook: - def test_conn_returns_a_boto3_connection(self): - hook = BedrockRuntimeHook() - - assert hook.conn is not None - assert hook.conn.meta.service_model.service_name == "bedrock-runtime" + def test_bedrock_hooks(self, test_hook, service_name): + assert test_hook.conn is not None + assert test_hook.conn.meta.service_model.service_name == service_name