From 903a5c43d5ea49a101faad2ede9bd2b12c4a86e1 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Mon, 2 May 2022 15:06:15 -0700 Subject: [PATCH 1/2] add validation decorator for multiapi args since inputs changed from v3.x to language api --- .../azure/ai/textanalytics/_check.py | 9 ++ .../azure/ai/textanalytics/_models.py | 11 +- .../ai/textanalytics/_request_handlers.py | 23 ---- .../textanalytics/_text_analytics_client.py | 111 +++++++----------- .../azure/ai/textanalytics/_validate.py | 97 +++++++++++++++ .../azure/ai/textanalytics/_version.py | 1 + .../aio/_text_analytics_client_async.py | 111 +++++++----------- 7 files changed, 196 insertions(+), 167 deletions(-) create mode 100644 sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_validate.py diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_check.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_check.py index 29cf38876dae..5845d98f878e 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_check.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_check.py @@ -11,3 +11,12 @@ def is_language_api(api_version): """Language API is date-based """ return re.search(r'\d{4}-\d{2}-\d{2}', api_version) + + +def string_index_type_compatibility(string_index_type): + """Language API changed this string_index_type option to plural. + Convert singular to plural for language API + """ + if string_index_type == "TextElement_v8": + return "TextElements_v8" + return string_index_type diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py index a9f16fc2db4a..b9112237beec 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py @@ -13,16 +13,7 @@ from ._generated.v3_0 import models as _v3_0_models from ._generated.v3_1 import models as _v3_1_models from ._generated.v2022_03_01_preview import models as _v2022_03_01_preview_models -from ._check import is_language_api - - -def string_index_type_compatibility(string_index_type): - """Language API changed this string_index_type option to plural. - Convert singular to plural for language API - """ - if string_index_type == "TextElement_v8": - return "TextElements_v8" - return string_index_type +from ._check import is_language_api, string_index_type_compatibility def _get_indices(relation): diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py index 26d13efa82a2..4401d441412d 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_request_handlers.py @@ -9,7 +9,6 @@ TextDocumentInput, _AnalyzeActionsType, ) -from ._check import is_language_api def _validate_input(documents, hint, whole_input_hint): @@ -99,25 +98,3 @@ def _determine_action_type(action): # pylint: disable=too-many-return-statement if action.__class__.__name__ == "CustomMultiLabelClassificationLROTask": return _AnalyzeActionsType.MULTI_CATEGORY_CLASSIFY return _AnalyzeActionsType.EXTRACT_KEY_PHRASES - - -def _check_string_index_type_arg( - string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint" -): - string_index_type = None - - if api_version == "v3.0": - if string_index_type_arg is not None: - raise ValueError( - "'string_index_type' is only available for API version V3_1 and up" - ) - elif is_language_api(api_version) and string_index_type_arg == "TextElement_v8": - return "TextElements_v8" - else: - if string_index_type_arg is None: - string_index_type = string_index_type_default - - else: - string_index_type = string_index_type_arg - - return string_index_type diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py index 0545f07b0907..6092cb0fb392 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_text_analytics_client.py @@ -16,13 +16,13 @@ from azure.core.tracing.decorator import distributed_trace from azure.core.exceptions import HttpResponseError from azure.core.credentials import AzureKeyCredential -from ._base_client import TextAnalyticsClientBase, TextAnalyticsApiVersion +from ._base_client import TextAnalyticsClientBase from ._lro import AnalyzeActionsLROPoller, AnalyzeHealthcareEntitiesLROPoller from ._request_handlers import ( _validate_input, _determine_action_type, - _check_string_index_type_arg, ) +from ._validate import validate_multiapi_args, check_for_unsupported_actions_types from ._version import DEFAULT_API_VERSION from ._response_handlers import ( process_http_response_error, @@ -67,7 +67,7 @@ MultiCategoryClassifyResult, _AnalyzeActionsType, ) -from ._check import is_language_api +from ._check import is_language_api, string_index_type_compatibility if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -132,6 +132,10 @@ def __init__( ) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["disable_service_logs"]} + ) def detect_language( self, documents: Union[List[str], List[DetectLanguageInput], List[Dict[str, str]]], @@ -228,6 +232,10 @@ def detect_language( return process_http_response_error(error) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]} + ) def recognize_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -294,11 +302,7 @@ def recognize_entities( docs = _validate_input(documents, "language", language) model_version = kwargs.pop("model_version", None) show_stats = kwargs.pop("show_stats", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_index_type_default, - ) + string_index_type = kwargs.pop("string_index_type", self._string_index_type_default) disable_service_logs = kwargs.pop("disable_service_logs", None) try: @@ -310,7 +314,7 @@ def recognize_entities( parameters=models.EntitiesTaskParameters( logging_opt_out=disable_service_logs, model_version=model_version, - string_index_type=string_index_type + string_index_type=string_index_type_compatibility(string_index_type) ) ), show_stats=show_stats, @@ -332,6 +336,9 @@ def recognize_entities( return process_http_response_error(error) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.1" + ) def recognize_pii_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -408,11 +415,7 @@ def recognize_pii_entities( show_stats = kwargs.pop("show_stats", None) domain_filter = kwargs.pop("domain_filter", None) categories_filter = kwargs.pop("categories_filter", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_index_type_default, - ) + string_index_type = kwargs.pop("string_index_type", self._string_index_type_default) disable_service_logs = kwargs.pop("disable_service_logs", None) try: @@ -426,7 +429,7 @@ def recognize_pii_entities( model_version=model_version, domain=domain_filter, pii_categories=categories_filter, - string_index_type=string_index_type + string_index_type=string_index_type_compatibility(string_index_type) ) ), show_stats=show_stats, @@ -446,19 +449,14 @@ def recognize_pii_entities( cls=kwargs.pop("cls", pii_entities_result), **kwargs ) - except ValueError as error: - if ( - "API version v3.0 does not have operation 'entities_recognition_pii'" - in str(error) - ): - raise ValueError( - "'recognize_pii_entities' endpoint is only available for API version V3_1 and up" - ) from error - raise error except HttpResponseError as error: return process_http_response_error(error) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]} + ) def recognize_linked_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -527,11 +525,7 @@ def recognize_linked_entities( model_version = kwargs.pop("model_version", None) show_stats = kwargs.pop("show_stats", None) disable_service_logs = kwargs.pop("disable_service_logs", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_index_type_default, - ) + string_index_type = kwargs.pop("string_index_type", self._string_index_type_default) try: if is_language_api(self._api_version): @@ -542,7 +536,7 @@ def recognize_linked_entities( parameters=models.EntityLinkingTaskParameters( logging_opt_out=disable_service_logs, model_version=model_version, - string_index_type=string_index_type + string_index_type=string_index_type_compatibility(string_index_type) ) ), show_stats=show_stats, @@ -581,6 +575,10 @@ def _healthcare_result_callback( ) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.1", + args_mapping={"2022-03-01-preview": ["display_name"]} + ) def begin_analyze_healthcare_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -658,11 +656,7 @@ def begin_analyze_healthcare_entities( show_stats = kwargs.pop("show_stats", None) polling_interval = kwargs.pop("polling_interval", 5) continuation_token = kwargs.pop("continuation_token", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_index_type_default, - ) + string_index_type = kwargs.pop("string_index_type", self._string_index_type_default) disable_service_logs = kwargs.pop("disable_service_logs", None) display_name = kwargs.pop("display_name", None) @@ -710,7 +704,7 @@ def get_result_from_cont_token(initial_response, pipeline_response): parameters=models.HealthcareTaskParameters( model_version=model_version, logging_opt_out=disable_service_logs, - string_index_type=string_index_type, + string_index_type=string_index_type_compatibility(string_index_type) ) ) ] @@ -755,19 +749,14 @@ def get_result_from_cont_token(initial_response, pipeline_response): continuation_token=continuation_token, **kwargs ) - - except ValueError as error: - if "API version v3.0 does not have operation 'begin_health'" in str(error): - raise ValueError( - "'begin_analyze_healthcare_entities' method is only available for API version \ - V3_1 and up." - ) from error - raise error - except HttpResponseError as error: return process_http_response_error(error) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["disable_service_logs"]} + ) def extract_key_phrases( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -862,6 +851,10 @@ def extract_key_phrases( return process_http_response_error(error) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["show_opinion_mining", "disable_service_logs", "string_index_type"]} + ) def analyze_sentiment( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -936,19 +929,7 @@ def analyze_sentiment( show_stats = kwargs.pop("show_stats", None) show_opinion_mining = kwargs.pop("show_opinion_mining", None) disable_service_logs = kwargs.pop("disable_service_logs", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_index_type_default, - ) - if show_opinion_mining is not None: - if ( - self._api_version == TextAnalyticsApiVersion.V3_0 - and show_opinion_mining - ): - raise ValueError( - "'show_opinion_mining' is only available for API version v3.1 and up" - ) + string_index_type = kwargs.pop("string_index_type", self._string_index_type_default) try: if is_language_api(self._api_version): @@ -959,7 +940,7 @@ def analyze_sentiment( parameters=models.SentimentAnalysisTaskParameters( logging_opt_out=disable_service_logs, model_version=model_version, - string_index_type=string_index_type, + string_index_type=string_index_type_compatibility(string_index_type), opinion_mining=show_opinion_mining, ) ), @@ -1001,6 +982,10 @@ def _analyze_result_callback( ) @distributed_trace + @validate_multiapi_args( + version_method_added="v3.1", + custom_wrapper=check_for_unsupported_actions_types + ) def begin_analyze_actions( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -1253,13 +1238,5 @@ def get_result_from_cont_token(initial_response, pipeline_response): continuation_token=continuation_token, **kwargs ) - - except ValueError as error: - if "API version v3.0 does not have operation 'begin_analyze'" in str(error): - raise ValueError( - "'begin_analyze_actions' endpoint is only available for API version V3_1 and up" - ) from error - raise error - except HttpResponseError as error: return process_http_response_error(error) diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_validate.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_validate.py new file mode 100644 index 000000000000..f72c8ecfbd32 --- /dev/null +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_validate.py @@ -0,0 +1,97 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +import functools +from ._version import VERSIONS_SUPPORTED + + +def check_for_unsupported_actions_types(*args, **kwargs): + + client = args[0] + # this assumes the client has an _api_version attribute + selected_api_version = client._api_version # pylint: disable=protected-access + + if "actions" not in kwargs: + actions = args[2] + else: + actions = kwargs.get("actions") + + if actions is None: + return + + actions_version_mapping = { + "2022-03-01-preview": + [ + "ExtractSummaryAction", + "RecognizeCustomEntitiesAction", + "SingleCategoryClassifyAction", + "MultiCategoryClassifyAction" + ] + } + + unsupported = { + arg: version + for version, args in actions_version_mapping.items() + for arg in args + if arg in [action.__class__.__name__ for action in actions] + and selected_api_version != version + and VERSIONS_SUPPORTED.index(selected_api_version) < VERSIONS_SUPPORTED.index(version) + } + + if unsupported: + error_strings = [ + f"'{param}' is only available for API version {version} and up.\n" + for param, version in unsupported.items() + ] + raise ValueError("".join(error_strings)) + + +def validate_multiapi_args(**kwargs): + args_mapping = kwargs.pop("args_mapping", None) + version_method_added = kwargs.pop("version_method_added", None) + custom_wrapper = kwargs.pop("custom_wrapper", None) + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + # this assumes the client has an _api_version attribute + client = args[0] + selected_api_version = client._api_version # pylint: disable=protected-access + except AttributeError: + return func(*args, **kwargs) + + # the latest version is selected, we assume all features supported + if selected_api_version == VERSIONS_SUPPORTED[-1]: + return func(*args, **kwargs) + + if version_method_added and version_method_added != selected_api_version and \ + VERSIONS_SUPPORTED.index(selected_api_version) < VERSIONS_SUPPORTED.index(version_method_added): + raise ValueError(f"'{func.__name__}' is only available for API version {version_method_added} and up.") + + if args_mapping: + unsupported = { + arg: version + for version, args in args_mapping.items() + for arg in args + if arg in kwargs.keys() + and selected_api_version != version + and VERSIONS_SUPPORTED.index(selected_api_version) < VERSIONS_SUPPORTED.index(version) + } + if unsupported: + error_strings = [ + f"'{param}' is only available for API version {version} and up.\n" + for param, version in unsupported.items() + ] + raise ValueError("".join(error_strings)) + + if custom_wrapper: + custom_wrapper(*args, **kwargs) + + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_version.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_version.py index 6d9e5f89708d..4d5335d4c074 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_version.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_version.py @@ -5,3 +5,4 @@ VERSION = "5.2.0b4" DEFAULT_API_VERSION = "2022-03-01-preview" +VERSIONS_SUPPORTED = ("v3.0", "v3.1", "2022-03-01-preview") diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py index c663275845c3..4986a6f7d529 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/aio/_text_analytics_client_async.py @@ -10,14 +10,13 @@ from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.exceptions import HttpResponseError from azure.core.credentials import AzureKeyCredential -from .._version import DEFAULT_API_VERSION from ._base_client_async import AsyncTextAnalyticsClientBase -from .._base_client import TextAnalyticsApiVersion from .._request_handlers import ( _validate_input, _determine_action_type, - _check_string_index_type_arg, ) +from .._validate import validate_multiapi_args, check_for_unsupported_actions_types +from .._version import DEFAULT_API_VERSION from .._response_handlers import ( process_http_response_error, entities_result, @@ -54,7 +53,7 @@ MultiCategoryClassifyAction, MultiCategoryClassifyResult, ) -from .._check import is_language_api +from .._check import is_language_api, string_index_type_compatibility from .._lro import TextAnalyticsOperationResourcePolling from ._lro_async import ( AsyncAnalyzeHealthcareEntitiesLROPollingMethod, @@ -127,6 +126,10 @@ def __init__( ) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["disable_service_logs"]} + ) async def detect_language( self, documents: Union[List[str], List[DetectLanguageInput], List[Dict[str, str]]], @@ -223,6 +226,10 @@ async def detect_language( return process_http_response_error(error) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]} + ) async def recognize_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -290,11 +297,7 @@ async def recognize_entities( model_version = kwargs.pop("model_version", None) show_stats = kwargs.pop("show_stats", None) disable_service_logs = kwargs.pop("disable_service_logs", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_code_unit, - ) + string_index_type = kwargs.pop("string_index_type", self._string_code_unit) try: if is_language_api(self._api_version): @@ -305,7 +308,7 @@ async def recognize_entities( parameters=models.EntitiesTaskParameters( logging_opt_out=disable_service_logs, model_version=model_version, - string_index_type=string_index_type + string_index_type=string_index_type_compatibility(string_index_type), ) ), show_stats=show_stats, @@ -327,6 +330,9 @@ async def recognize_entities( return process_http_response_error(error) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.1" + ) async def recognize_pii_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -403,11 +409,7 @@ async def recognize_pii_entities( show_stats = kwargs.pop("show_stats", None) domain_filter = kwargs.pop("domain_filter", None) categories_filter = kwargs.pop("categories_filter", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_code_unit, - ) + string_index_type = kwargs.pop("string_index_type", self._string_code_unit) disable_service_logs = kwargs.pop("disable_service_logs", None) try: @@ -421,7 +423,7 @@ async def recognize_pii_entities( model_version=model_version, domain=domain_filter, pii_categories=categories_filter, - string_index_type=string_index_type + string_index_type=string_index_type_compatibility(string_index_type), ) ), show_stats=show_stats, @@ -441,19 +443,14 @@ async def recognize_pii_entities( cls=kwargs.pop("cls", pii_entities_result), **kwargs, ) - except ValueError as error: - if ( - "API version v3.0 does not have operation 'entities_recognition_pii'" - in str(error) - ): - raise ValueError( - "'recognize_pii_entities' endpoint is only available for API version V3_1 and up" - ) from error - raise error except HttpResponseError as error: return process_http_response_error(error) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]} + ) async def recognize_linked_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -522,11 +519,7 @@ async def recognize_linked_entities( model_version = kwargs.pop("model_version", None) show_stats = kwargs.pop("show_stats", None) disable_service_logs = kwargs.pop("disable_service_logs", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_code_unit, - ) + string_index_type = kwargs.pop("string_index_type", self._string_code_unit) try: if is_language_api(self._api_version): @@ -537,7 +530,7 @@ async def recognize_linked_entities( parameters=models.EntityLinkingTaskParameters( logging_opt_out=disable_service_logs, model_version=model_version, - string_index_type=string_index_type + string_index_type=string_index_type_compatibility(string_index_type), ) ), show_stats=show_stats, @@ -559,6 +552,10 @@ async def recognize_linked_entities( return process_http_response_error(error) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["disable_service_logs"]} + ) async def extract_key_phrases( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -653,6 +650,10 @@ async def extract_key_phrases( return process_http_response_error(error) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.0", + args_mapping={"v3.1": ["show_opinion_mining", "disable_service_logs", "string_index_type"]} + ) async def analyze_sentiment( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -727,19 +728,7 @@ async def analyze_sentiment( show_stats = kwargs.pop("show_stats", None) show_opinion_mining = kwargs.pop("show_opinion_mining", None) disable_service_logs = kwargs.pop("disable_service_logs", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_code_unit, - ) - if show_opinion_mining is not None: - if ( - self._api_version == TextAnalyticsApiVersion.V3_0 - and show_opinion_mining - ): - raise ValueError( - "'show_opinion_mining' is only available for API version v3.1 and up" - ) + string_index_type = kwargs.pop("string_index_type", self._string_code_unit) try: if is_language_api(self._api_version): @@ -750,7 +739,7 @@ async def analyze_sentiment( parameters=models.SentimentAnalysisTaskParameters( logging_opt_out=disable_service_logs, model_version=model_version, - string_index_type=string_index_type, + string_index_type=string_index_type_compatibility(string_index_type), opinion_mining=show_opinion_mining, ) ), @@ -791,6 +780,10 @@ def _healthcare_result_callback( ) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.1", + args_mapping={"2022-03-01-preview": ["display_name"]} + ) async def begin_analyze_healthcare_entities( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -869,11 +862,7 @@ async def begin_analyze_healthcare_entities( show_stats = kwargs.pop("show_stats", None) polling_interval = kwargs.pop("polling_interval", 5) continuation_token = kwargs.pop("continuation_token", None) - string_index_type = _check_string_index_type_arg( - kwargs.pop("string_index_type", None), - self._api_version, - string_index_type_default=self._string_code_unit, - ) + string_index_type = kwargs.pop("string_index_type", self._string_code_unit) disable_service_logs = kwargs.pop("disable_service_logs", None) display_name = kwargs.pop("display_name", None) @@ -921,7 +910,7 @@ def get_result_from_cont_token(initial_response, pipeline_response): parameters=models.HealthcareTaskParameters( model_version=model_version, logging_opt_out=disable_service_logs, - string_index_type=string_index_type, + string_index_type=string_index_type_compatibility(string_index_type), ) ) ] @@ -966,14 +955,6 @@ def get_result_from_cont_token(initial_response, pipeline_response): continuation_token=continuation_token, **kwargs, ) - - except ValueError as error: - if "API version v3.0 does not have operation 'begin_health'" in str(error): - raise ValueError( - "'begin_analyze_healthcare_entities' endpoint is only available for API version V3_1 and up" - ) from error - raise error - except HttpResponseError as error: return process_http_response_error(error) @@ -996,6 +977,10 @@ def _analyze_result_callback( ) @distributed_trace_async + @validate_multiapi_args( + version_method_added="v3.1", + custom_wrapper=check_for_unsupported_actions_types + ) async def begin_analyze_actions( self, documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]], @@ -1250,13 +1235,5 @@ def get_result_from_cont_token(initial_response, pipeline_response): continuation_token=continuation_token, **kwargs, ) - - except ValueError as error: - if "API version v3.0 does not have operation 'begin_analyze'" in str(error): - raise ValueError( - "'begin_analyze_actions' endpoint is only available for API version V3_1 and up" - ) from error - raise error - except HttpResponseError as error: return process_http_response_error(error) From 8041496ae6f56fba32fe8d64e10db2c4ac932285 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Mon, 2 May 2022 15:06:28 -0700 Subject: [PATCH 2/2] add tests --- .../tests/test_analyze.py | 68 ++++++++++++++++++ .../tests/test_analyze_async.py | 69 +++++++++++++++++++ .../tests/test_analyze_healthcare.py | 39 +++++++++++ .../tests/test_analyze_healthcare_async.py | 39 +++++++++++ .../tests/test_analyze_sentiment.py | 37 +++++----- .../tests/test_analyze_sentiment_async.py | 29 +++++--- .../tests/test_detect_language.py | 9 +++ .../tests/test_detect_language_async.py | 9 +++ .../tests/test_extract_key_phrases.py | 9 +++ .../tests/test_extract_key_phrases_async.py | 9 +++ .../tests/test_recognize_entities.py | 33 ++++----- .../tests/test_recognize_entities_async.py | 33 ++++----- .../tests/test_recognize_linked_entities.py | 25 ++++--- .../test_recognize_linked_entities_async.py | 27 +++++--- .../tests/test_recognize_pii_entities.py | 30 ++++---- .../test_recognize_pii_entities_async.py | 28 +++----- 16 files changed, 386 insertions(+), 107 deletions(-) diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py index 095917d87627..62c7aadb24ff 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py @@ -1728,3 +1728,71 @@ def test_analyze_works_with_v3_1(self, client): assert document_result.id == document_order[doc_idx] assert not document_result.is_error assert self.document_result_to_action_type(document_result) == action_order[action_idx] + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_analyze_multiapi_validate_v3_0(self, **kwargs): + client = kwargs.pop("client") + docs = [{"id": "56", "text": ":)"}, + {"id": "0", "text": ":("}, + {"id": "19", "text": ":P"}, + {"id": "1", "text": ":D"}] + + with pytest.raises(ValueError) as e: + response = client.begin_analyze_actions( + docs, + actions=[ + RecognizeEntitiesAction(), + ExtractKeyPhrasesAction(), + RecognizePiiEntitiesAction(), + RecognizeLinkedEntitiesAction(), + AnalyzeSentimentAction() + ], + polling_interval=self._interval(), + ).result() + assert str(e.value) == "'begin_analyze_actions' is only available for API version v3.1 and up." + + @TextAnalyticsPreparer() + @TextAnalyticsCustomPreparer() + def test_analyze_multiapi_validate_v3_1(self, **kwargs): + textanalytics_custom_text_endpoint = kwargs.pop("textanalytics_custom_text_endpoint") + textanalytics_custom_text_key = kwargs.pop("textanalytics_custom_text_key") + textanalytics_single_category_classify_project_name = kwargs.pop("textanalytics_single_category_classify_project_name") + textanalytics_single_category_classify_deployment_name = kwargs.pop("textanalytics_single_category_classify_deployment_name") + textanalytics_multi_category_classify_project_name = kwargs.pop("textanalytics_multi_category_classify_project_name") + textanalytics_multi_category_classify_deployment_name = kwargs.pop("textanalytics_multi_category_classify_deployment_name") + textanalytics_custom_entities_project_name = kwargs.pop("textanalytics_custom_entities_project_name") + textanalytics_custom_entities_deployment_name = kwargs.pop("textanalytics_custom_entities_deployment_name") + + client = TextAnalyticsClient(textanalytics_custom_text_endpoint, AzureKeyCredential(textanalytics_custom_text_key), api_version="v3.1") + + docs = [{"id": "56", "text": ":)"}, + {"id": "0", "text": ":("}, + {"id": "19", "text": ":P"}, + {"id": "1", "text": ":D"}] + version_supported = "2022-03-01-preview" + with pytest.raises(ValueError) as e: + response = client.begin_analyze_actions( + docs, + actions=[ + SingleCategoryClassifyAction( + project_name=textanalytics_single_category_classify_project_name, + deployment_name=textanalytics_single_category_classify_deployment_name + ), + MultiCategoryClassifyAction( + project_name=textanalytics_multi_category_classify_project_name, + deployment_name=textanalytics_multi_category_classify_deployment_name + ), + RecognizeCustomEntitiesAction( + project_name=textanalytics_custom_entities_project_name, + deployment_name=textanalytics_custom_entities_deployment_name + ), + ExtractSummaryAction() + ], + polling_interval=self._interval(), + ).result() + assert str(e.value) == f"'ExtractSummaryAction' is only available for API version {version_supported} and " \ + f"up.\n'RecognizeCustomEntitiesAction' is only available for API version " \ + f"{version_supported} and up.\n'SingleCategoryClassifyAction' is only available " \ + f"for API version {version_supported} and up.\n'MultiCategoryClassifyAction' is " \ + f"only available for API version {version_supported} and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py index 0c35cee52f80..c81039410f4b 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py @@ -1842,3 +1842,72 @@ async def test_analyze_works_with_v3_1(self, client): assert document_result.id == document_order[doc_idx] assert not document_result.is_error assert self.document_result_to_action_type(document_result) == action_order[action_idx] + + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_analyze_multiapi_validate_v3_0(self, **kwargs): + client = kwargs.pop("client") + docs = [{"id": "56", "text": ":)"}, + {"id": "0", "text": ":("}, + {"id": "19", "text": ":P"}, + {"id": "1", "text": ":D"}] + + with pytest.raises(ValueError) as e: + response = await (await client.begin_analyze_actions( + docs, + actions=[ + RecognizeEntitiesAction(), + ExtractKeyPhrasesAction(), + RecognizePiiEntitiesAction(), + RecognizeLinkedEntitiesAction(), + AnalyzeSentimentAction() + ], + polling_interval=self._interval(), + )).result() + assert str(e.value) == "'begin_analyze_actions' is only available for API version v3.1 and up." + + @TextAnalyticsPreparer() + @TextAnalyticsCustomPreparer() + async def test_analyze_multiapi_validate_v3_1(self, **kwargs): + textanalytics_custom_text_endpoint = kwargs.pop("textanalytics_custom_text_endpoint") + textanalytics_custom_text_key = kwargs.pop("textanalytics_custom_text_key") + textanalytics_single_category_classify_project_name = kwargs.pop("textanalytics_single_category_classify_project_name") + textanalytics_single_category_classify_deployment_name = kwargs.pop("textanalytics_single_category_classify_deployment_name") + textanalytics_multi_category_classify_project_name = kwargs.pop("textanalytics_multi_category_classify_project_name") + textanalytics_multi_category_classify_deployment_name = kwargs.pop("textanalytics_multi_category_classify_deployment_name") + textanalytics_custom_entities_project_name = kwargs.pop("textanalytics_custom_entities_project_name") + textanalytics_custom_entities_deployment_name = kwargs.pop("textanalytics_custom_entities_deployment_name") + + client = TextAnalyticsClient(textanalytics_custom_text_endpoint, AzureKeyCredential(textanalytics_custom_text_key), api_version="v3.1") + + docs = [{"id": "56", "text": ":)"}, + {"id": "0", "text": ":("}, + {"id": "19", "text": ":P"}, + {"id": "1", "text": ":D"}] + version_supported = "2022-03-01-preview" + with pytest.raises(ValueError) as e: + response = await (await client.begin_analyze_actions( + docs, + actions=[ + SingleCategoryClassifyAction( + project_name=textanalytics_single_category_classify_project_name, + deployment_name=textanalytics_single_category_classify_deployment_name + ), + MultiCategoryClassifyAction( + project_name=textanalytics_multi_category_classify_project_name, + deployment_name=textanalytics_multi_category_classify_deployment_name + ), + RecognizeCustomEntitiesAction( + project_name=textanalytics_custom_entities_project_name, + deployment_name=textanalytics_custom_entities_deployment_name + ), + ExtractSummaryAction() + ], + polling_interval=self._interval(), + )).result() + assert str(e.value) == f"'ExtractSummaryAction' is only available for API version {version_supported} and " \ + f"up.\n'RecognizeCustomEntitiesAction' is only available for API version " \ + f"{version_supported} and up.\n'SingleCategoryClassifyAction' is only available " \ + f"for API version {version_supported} and up.\n'MultiCategoryClassifyAction' is " \ + f"only available for API version {version_supported} and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare.py index a6db68bf7f3e..947f6f22c0c1 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare.py @@ -549,3 +549,42 @@ def test_poller_metadata(self, client): assert isinstance(poller.expires_on, datetime.datetime) assert isinstance(poller.last_modified_on, datetime.datetime) assert poller.id + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_healthcare_multiapi_validate_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + poller = client.begin_analyze_healthcare_entities( + documents=[ + {"id": "1", + "text": "Baby not likely to have Meningitis. In case of fever in the mother, consider Penicillin for the baby too."}, + {"id": "2", "text": "patients must have histologically confirmed NHL"}, + {"id": "3", "text": ""}, + {"id": "4", "text": "The patient was diagnosed with Parkinsons Disease (PD)"} + ], + show_stats=True, + polling_interval=self._interval(), + ) + assert str(e.value) == "'begin_analyze_healthcare_entities' is only available for API version v3.1 and up." + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.1"}) + def test_healthcare_multiapi_validate_v3_1(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + poller = client.begin_analyze_healthcare_entities( + documents=[ + {"id": "1", + "text": "Baby not likely to have Meningitis. In case of fever in the mother, consider Penicillin for the baby too."}, + {"id": "2", "text": "patients must have histologically confirmed NHL"}, + {"id": "3", "text": ""}, + {"id": "4", "text": "The patient was diagnosed with Parkinsons Disease (PD)"} + ], + display_name="this won't work", + show_stats=True, + polling_interval=self._interval(), + ) + assert str(e.value) == "'display_name' is only available for API version 2022-03-01-preview and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare_async.py index 4f6fb17d8566..5217417a37af 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_healthcare_async.py @@ -578,3 +578,42 @@ async def test_poller_metadata(self, client): assert isinstance(poller.expires_on, datetime.datetime) assert isinstance(poller.last_modified_on, datetime.datetime) assert poller.id + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_healthcare_multiapi_validate_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + poller = await client.begin_analyze_healthcare_entities( + documents=[ + {"id": "1", + "text": "Baby not likely to have Meningitis. In case of fever in the mother, consider Penicillin for the baby too."}, + {"id": "2", "text": "patients must have histologically confirmed NHL"}, + {"id": "3", "text": ""}, + {"id": "4", "text": "The patient was diagnosed with Parkinsons Disease (PD)"} + ], + show_stats=True, + polling_interval=self._interval(), + ) + assert str(e.value) == "'begin_analyze_healthcare_entities' is only available for API version v3.1 and up." + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.1"}) + async def test_healthcare_multiapi_validate_v3_1(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + poller = await client.begin_analyze_healthcare_entities( + documents=[ + {"id": "1", + "text": "Baby not likely to have Meningitis. In case of fever in the mother, consider Penicillin for the baby too."}, + {"id": "2", "text": "patients must have histologically confirmed NHL"}, + {"id": "3", "text": ""}, + {"id": "4", "text": "The patient was diagnosed with Parkinsons Disease (PD)"} + ], + display_name="this won't work", + show_stats=True, + polling_interval=self._interval(), + ) + assert str(e.value) == "'display_name' is only available for API version 2022-03-01-preview and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment.py index ed05dab043c1..b5df8cb5977e 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment.py @@ -760,22 +760,6 @@ def test_no_offset_v3_sentence_sentiment(self, client): assert sentences[0].offset is None assert sentences[1].offset is None - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_string_index_type_not_fail_v3(self, client): - # make sure that the addition of the string_index_type kwarg for v3.1-preview.1 doesn't - # cause v3.0 calls to fail - client.analyze_sentiment(["please don't fail"]) - - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - client.analyze_sentiment(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy @@ -849,3 +833,24 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_sentiment_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = client.analyze_sentiment(["I'm tired"], string_index_type="UnicodeCodePoint") + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.analyze_sentiment(["I'm tired"], show_opinion_mining=True) + assert str(e.value) == "'show_opinion_mining' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.analyze_sentiment(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.analyze_sentiment(["I'm tired"], show_opinion_mining=True, disable_service_logs=True, string_index_type="UnicodeCodePoint") + assert str(e.value) == "'show_opinion_mining' is only available for API version v3.1 and up.\n'disable_service_logs' is only available for API version v3.1 and up.\n'string_index_type' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment_async.py index d164ca5ee3d2..b9e59c68d621 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_sentiment_async.py @@ -771,14 +771,6 @@ async def test_string_index_type_not_fail_v3(self, client): # cause v3.0 calls to fail await client.analyze_sentiment(["please don't fail"]) - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy_async - async def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - await client.analyze_sentiment(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy_async @@ -852,3 +844,24 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_sentiment_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = await client.analyze_sentiment(["I'm tired"], string_index_type="UnicodeCodePoint") + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.analyze_sentiment(["I'm tired"], show_opinion_mining=True) + assert str(e.value) == "'show_opinion_mining' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.analyze_sentiment(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.analyze_sentiment(["I'm tired"], show_opinion_mining=True, disable_service_logs=True, string_index_type="UnicodeCodePoint") + assert str(e.value) == "'show_opinion_mining' is only available for API version v3.1 and up.\n'disable_service_logs' is only available for API version v3.1 and up.\n'string_index_type' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language.py index ece4c47a1b8d..c922b2f7a161 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language.py @@ -657,3 +657,12 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_language_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = await client.detect_language(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language_async.py index eab966ccaf66..b922cff3a8ac 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_detect_language_async.py @@ -658,3 +658,12 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_language_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = await client.detect_language(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases.py index b9fad233841a..436a25b4a1be 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases.py @@ -571,3 +571,12 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_key_phrases_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = client.extract_key_phrases(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases_async.py index 439ab2cdaf4d..7560ab117e69 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_extract_key_phrases_async.py @@ -573,3 +573,12 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_key_phrases_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = await client.extract_key_phrases(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities.py index c957292bb72a..2d03c3c768ab 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities.py @@ -598,22 +598,6 @@ def test_no_offset_v3_categorized_entities(self, client): assert entities[1].offset is None assert entities[2].offset is None - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_string_index_type_not_fail_v3(self, client): - # make sure that the addition of the string_index_type kwarg for v3.1-preview doesn't - # cause v3.0 calls to fail - client.recognize_entities(["please don't fail"]) - - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - client.recognize_entities(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy @@ -687,3 +671,20 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_entities_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = client.recognize_entities(["I'm tired"], string_index_type="UnicodeCodePoint") + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.recognize_entities(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.recognize_entities(["I'm tired"], string_index_type="UnicodeCodePoint", disable_service_logs=True) + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities_async.py index be420908a2e3..2bab2f1027f2 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_entities_async.py @@ -605,22 +605,6 @@ async def test_no_offset_v3_categorized_entities(self, client): assert entities[1].offset is None assert entities[2].offset is None - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy_async - async def test_string_index_type_not_fail_v3(self, client): - # make sure that the addition of the string_index_type kwarg for v3.1-preview doesn't - # cause v3.0 calls to fail - await client.recognize_entities(["please don't fail"]) - - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy_async - async def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - await client.recognize_entities(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy_async @@ -694,3 +678,20 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_entities_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = await client.recognize_entities(["I'm tired"], string_index_type="UnicodeCodePoint") + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.recognize_entities(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.recognize_entities(["I'm tired"], string_index_type="UnicodeCodePoint", disable_service_logs=True) + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities.py index 0ef0ca2c89a6..40f628bf23a7 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities.py @@ -622,14 +622,6 @@ def test_bing_id(self, client): for entity in doc.entities: assert entity.bing_entity_search_api_id # this checks if it's None and if it's empty - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - client.recognize_linked_entities(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy @@ -703,3 +695,20 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_linked_entities_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = client.recognize_linked_entities(["I'm tired"], string_index_type="UnicodeCodePoint") + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.recognize_linked_entities(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = client.recognize_linked_entities(["I'm tired"], string_index_type="UnicodeCodePoint", disable_service_logs=True) + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities_async.py index 929795429684..e246c017aac0 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_linked_entities_async.py @@ -646,14 +646,6 @@ async def test_bing_id(self, client): for entity in doc.entities: assert entity.bing_entity_search_api_id # this checks if it's None and if it's empty - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy_async - async def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - await client.recognize_linked_entities(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy_async @@ -729,3 +721,22 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_linked_entities_multiapi_validate_args_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + res = await client.recognize_linked_entities(["I'm tired"], string_index_type="UnicodeCodePoint") + assert str(e.value) == "'string_index_type' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.recognize_linked_entities(["I'm tired"], disable_service_logs=True) + assert str(e.value) == "'disable_service_logs' is only available for API version v3.1 and up.\n" + + with pytest.raises(ValueError) as e: + res = await client.recognize_linked_entities(["I'm tired"], string_index_type="UnicodeCodePoint", + disable_service_logs=True) + assert str( + e.value) == "'string_index_type' is only available for API version v3.1 and up.\n'disable_service_logs' is only available for API version v3.1 and up.\n" diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities.py index a3778ef33628..ac641b5114ff 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities.py @@ -600,15 +600,6 @@ def callback(response): raw_response_hook=callback ) - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_recognize_pii_entities_v3(self, client): - with pytest.raises(ValueError) as excinfo: - client.recognize_pii_entities(["this should fail"]) - - assert "'recognize_pii_entities' endpoint is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer() @recorded_by_proxy @@ -667,14 +658,6 @@ def test_categories_filter_with_domain_filter(self, client): entity = result[0].entities[0] assert entity.category == PiiEntityCategory.US_SOCIAL_SECURITY_NUMBER.value - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy - def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - client.recognize_pii_entities(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy @@ -747,4 +730,15 @@ def callback(resp): documents=["Test for logging disable"], disable_service_logs=True, raw_response_hook=callback, - ) \ No newline at end of file + ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + def test_pii_entities_multiapi_validate_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + client.recognize_pii_entities( + documents=["Test"] + ) + assert str(e.value) == "'recognize_pii_entities' is only available for API version v3.1 and up." diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities_async.py index 62127cbb1ec3..2b1cc846d667 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_recognize_pii_entities_async.py @@ -602,15 +602,6 @@ def callback(response): raw_response_hook=callback ) - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy_async - async def test_recognize_pii_entities_v3(self, client): - with pytest.raises(ValueError) as excinfo: - await client.recognize_pii_entities(["this should fail"]) - - assert "'recognize_pii_entities' endpoint is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer() @recorded_by_proxy_async @@ -668,14 +659,6 @@ async def test_categories_filter_with_domain_filter(self, client): entity = result[0].entities[0] assert entity.category == PiiEntityCategory.US_SOCIAL_SECURITY_NUMBER.value - @TextAnalyticsPreparer() - @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_0}) - @recorded_by_proxy_async - async def test_string_index_type_explicit_fails_v3(self, client): - with pytest.raises(ValueError) as excinfo: - await client.recognize_pii_entities(["this should fail"], string_index_type="UnicodeCodePoint") - assert "'string_index_type' is only available for API version V3_1 and up" in str(excinfo.value) - @TextAnalyticsPreparer() @TextAnalyticsClientPreparer(client_kwargs={"api_version": TextAnalyticsApiVersion.V3_1}) @recorded_by_proxy_async @@ -749,3 +732,14 @@ def callback(resp): disable_service_logs=True, raw_response_hook=callback, ) + + @TextAnalyticsPreparer() + @TextAnalyticsClientPreparer(client_kwargs={"api_version": "v3.0"}) + async def test_pii_entities_multiapi_validate_v3_0(self, **kwargs): + client = kwargs.pop("client") + + with pytest.raises(ValueError) as e: + await client.recognize_pii_entities( + documents=["Test"] + ) + assert str(e.value) == "'recognize_pii_entities' is only available for API version v3.1 and up."