Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,12 @@ def is_language_api(api_version):
"""Language API is date-based
"""
return re.search(r'\d{4}-\d{2}-\d{2}', api_version)


def string_index_type_compatibility(string_index_type):
"""Language API changed this string_index_type option to plural.
Convert singular to plural for language API
"""
if string_index_type == "TextElement_v8":
return "TextElements_v8"
return string_index_type
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,7 @@
from ._generated.v3_0 import models as _v3_0_models
from ._generated.v3_1 import models as _v3_1_models
from ._generated.v2022_03_01_preview import models as _v2022_03_01_preview_models
from ._check import is_language_api


def string_index_type_compatibility(string_index_type):
"""Language API changed this string_index_type option to plural.
Convert singular to plural for language API
"""
if string_index_type == "TextElement_v8":
return "TextElements_v8"
return string_index_type
from ._check import is_language_api, string_index_type_compatibility


def _get_indices(relation):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
TextDocumentInput,
_AnalyzeActionsType,
)
from ._check import is_language_api


def _validate_input(documents, hint, whole_input_hint):
Expand Down Expand Up @@ -99,25 +98,3 @@ def _determine_action_type(action): # pylint: disable=too-many-return-statement
if action.__class__.__name__ == "CustomMultiLabelClassificationLROTask":
return _AnalyzeActionsType.MULTI_CATEGORY_CLASSIFY
return _AnalyzeActionsType.EXTRACT_KEY_PHRASES


def _check_string_index_type_arg(
string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint"
):
string_index_type = None

if api_version == "v3.0":
if string_index_type_arg is not None:
raise ValueError(
"'string_index_type' is only available for API version V3_1 and up"
)
elif is_language_api(api_version) and string_index_type_arg == "TextElement_v8":
return "TextElements_v8"
else:
if string_index_type_arg is None:
string_index_type = string_index_type_default

else:
string_index_type = string_index_type_arg

return string_index_type
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import HttpResponseError
from azure.core.credentials import AzureKeyCredential
from ._base_client import TextAnalyticsClientBase, TextAnalyticsApiVersion
from ._base_client import TextAnalyticsClientBase
from ._lro import AnalyzeActionsLROPoller, AnalyzeHealthcareEntitiesLROPoller
from ._request_handlers import (
_validate_input,
_determine_action_type,
_check_string_index_type_arg,
)
from ._validate import validate_multiapi_args, check_for_unsupported_actions_types
from ._version import DEFAULT_API_VERSION
from ._response_handlers import (
process_http_response_error,
Expand Down Expand Up @@ -67,7 +67,7 @@
MultiCategoryClassifyResult,
_AnalyzeActionsType,
)
from ._check import is_language_api
from ._check import is_language_api, string_index_type_compatibility

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
Expand Down Expand Up @@ -132,6 +132,10 @@ def __init__(
)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.0",
args_mapping={"v3.1": ["disable_service_logs"]}
)
def detect_language(
self,
documents: Union[List[str], List[DetectLanguageInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -228,6 +232,10 @@ def detect_language(
return process_http_response_error(error)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.0",
args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]}
)
def recognize_entities(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -294,11 +302,7 @@ def recognize_entities(
docs = _validate_input(documents, "language", language)
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", None)
string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
disable_service_logs = kwargs.pop("disable_service_logs", None)

try:
Expand All @@ -310,7 +314,7 @@ def recognize_entities(
parameters=models.EntitiesTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type
string_index_type=string_index_type_compatibility(string_index_type)
)
),
show_stats=show_stats,
Expand All @@ -332,6 +336,9 @@ def recognize_entities(
return process_http_response_error(error)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.1"
)
def recognize_pii_entities(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -408,11 +415,7 @@ def recognize_pii_entities(
show_stats = kwargs.pop("show_stats", None)
domain_filter = kwargs.pop("domain_filter", None)
categories_filter = kwargs.pop("categories_filter", None)
string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
disable_service_logs = kwargs.pop("disable_service_logs", None)

try:
Expand All @@ -426,7 +429,7 @@ def recognize_pii_entities(
model_version=model_version,
domain=domain_filter,
pii_categories=categories_filter,
string_index_type=string_index_type
string_index_type=string_index_type_compatibility(string_index_type)
)
),
show_stats=show_stats,
Expand All @@ -446,19 +449,14 @@ def recognize_pii_entities(
cls=kwargs.pop("cls", pii_entities_result),
**kwargs
)
except ValueError as error:
if (
"API version v3.0 does not have operation 'entities_recognition_pii'"
in str(error)
):
raise ValueError(
"'recognize_pii_entities' endpoint is only available for API version V3_1 and up"
) from error
raise error
except HttpResponseError as error:
return process_http_response_error(error)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.0",
args_mapping={"v3.1": ["string_index_type", "disable_service_logs"]}
)
def recognize_linked_entities(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -527,11 +525,7 @@ def recognize_linked_entities(
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", None)
disable_service_logs = kwargs.pop("disable_service_logs", None)
string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)

try:
if is_language_api(self._api_version):
Expand All @@ -542,7 +536,7 @@ def recognize_linked_entities(
parameters=models.EntityLinkingTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type
string_index_type=string_index_type_compatibility(string_index_type)
)
),
show_stats=show_stats,
Expand Down Expand Up @@ -581,6 +575,10 @@ def _healthcare_result_callback(
)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.1",
args_mapping={"2022-03-01-preview": ["display_name"]}
)
def begin_analyze_healthcare_entities(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -658,11 +656,7 @@ def begin_analyze_healthcare_entities(
show_stats = kwargs.pop("show_stats", None)
polling_interval = kwargs.pop("polling_interval", 5)
continuation_token = kwargs.pop("continuation_token", None)
string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)
disable_service_logs = kwargs.pop("disable_service_logs", None)
display_name = kwargs.pop("display_name", None)

Expand Down Expand Up @@ -710,7 +704,7 @@ def get_result_from_cont_token(initial_response, pipeline_response):
parameters=models.HealthcareTaskParameters(
model_version=model_version,
logging_opt_out=disable_service_logs,
string_index_type=string_index_type,
string_index_type=string_index_type_compatibility(string_index_type)
)
)
]
Expand Down Expand Up @@ -755,19 +749,14 @@ def get_result_from_cont_token(initial_response, pipeline_response):
continuation_token=continuation_token,
**kwargs
)

except ValueError as error:
if "API version v3.0 does not have operation 'begin_health'" in str(error):
raise ValueError(
"'begin_analyze_healthcare_entities' method is only available for API version \
V3_1 and up."
) from error
raise error

except HttpResponseError as error:
return process_http_response_error(error)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.0",
args_mapping={"v3.1": ["disable_service_logs"]}
)
def extract_key_phrases(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -862,6 +851,10 @@ def extract_key_phrases(
return process_http_response_error(error)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.0",
args_mapping={"v3.1": ["show_opinion_mining", "disable_service_logs", "string_index_type"]}
)
def analyze_sentiment(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -936,19 +929,7 @@ def analyze_sentiment(
show_stats = kwargs.pop("show_stats", None)
show_opinion_mining = kwargs.pop("show_opinion_mining", None)
disable_service_logs = kwargs.pop("disable_service_logs", None)
string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
if show_opinion_mining is not None:
if (
self._api_version == TextAnalyticsApiVersion.V3_0
and show_opinion_mining
):
raise ValueError(
"'show_opinion_mining' is only available for API version v3.1 and up"
)
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)

try:
if is_language_api(self._api_version):
Expand All @@ -959,7 +940,7 @@ def analyze_sentiment(
parameters=models.SentimentAnalysisTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type,
string_index_type=string_index_type_compatibility(string_index_type),
opinion_mining=show_opinion_mining,
)
),
Expand Down Expand Up @@ -1001,6 +982,10 @@ def _analyze_result_callback(
)

@distributed_trace
@validate_multiapi_args(
version_method_added="v3.1",
custom_wrapper=check_for_unsupported_actions_types
)
def begin_analyze_actions(
self,
documents: Union[List[str], List[TextDocumentInput], List[Dict[str, str]]],
Expand Down Expand Up @@ -1253,13 +1238,5 @@ def get_result_from_cont_token(initial_response, pipeline_response):
continuation_token=continuation_token,
**kwargs
)

except ValueError as error:
if "API version v3.0 does not have operation 'begin_analyze'" in str(error):
raise ValueError(
"'begin_analyze_actions' endpoint is only available for API version V3_1 and up"
) from error
raise error

except HttpResponseError as error:
return process_http_response_error(error)
Loading