Skip to content
Merged
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
Expand All @@ -43,7 +42,6 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding)>());
Expand Down
33 changes: 0 additions & 33 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1089,39 +1089,6 @@ The bounding box coordinates corresponding to the selected indices can then be o
updateOutputShape(ctx, 0, input_shape);
});

ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "X", "Strings to normalize", "T")
.Output(0, "Y", "Normalized strings", "T")
.TypeConstraint(
"T",
{"tensor(string)"},
"Input/Output is a string tensor")
.Attr(
"casechangeaction",
"string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"",
AttributeProto::STRING)
.Attr(
"is_case_sensitive",
"Boolean. Whether the identification of stop words in X is case-sensitive.",
AttributeProto::INT)
.Attr(
"stopwords",
"List of stop words",
AttributeProto::STRINGS,
OPTIONAL)
.Attr(
"locale",
"Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US",
AttributeProto::STRING,
OPTIONAL)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING);
})
.SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC");

ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND)
.SetDomain(kMSDomain)
.SinceVersion(1)
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Where);

// Opset 10
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer);

void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Clip)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Elu)>());
Expand Down Expand Up @@ -529,6 +532,9 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, NonZero)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Where)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, Where)>());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this empty line?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I need to add a comment // Opset 10

// Opset 10
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, StringNormalizer)>());
}

// Forward declarations of ml op kernels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
#include <unordered_set>

namespace onnxruntime {
namespace contrib {

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
ONNX_CPU_OPERATOR_KERNEL(
StringNormalizer,
1,
string,
10,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
contrib::StringNormalizer);
StringNormalizer);

namespace string_normalizer {
const std::string conv_error("Conversion Error");
Expand Down Expand Up @@ -157,29 +155,29 @@ using namespace string_normalizer;

StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info),
is_case_sensitive_(true),
casechangeaction_(NONE),
case_change_action_(NONE),
compare_caseaction_(NONE) {
int64_t iscasesensitive = 0;
Status status = info.GetAttr("is_case_sensitive", &iscasesensitive);
ORT_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set");
is_case_sensitive_ = iscasesensitive != 0;

std::string casechangeaction;
status = info.GetAttr("casechangeaction", &casechangeaction);
ORT_ENFORCE(status.IsOK(), "attribute caseaction is not set");
if (casechangeaction == "LOWER") {
casechangeaction_ = LOWER;
} else if (casechangeaction == "UPPER") {
casechangeaction_ = UPPER;
} else if (casechangeaction == "NONE") {
casechangeaction_ = NONE;
std::string case_change_action;
status = info.GetAttr("case_change_action", &case_change_action);
ORT_ENFORCE(status.IsOK(), "attribute case_change_action is not set");
if (case_change_action == "LOWER") {
case_change_action_ = LOWER;
} else if (case_change_action == "UPPER") {
case_change_action_ = UPPER;
} else if (case_change_action == "NONE") {
case_change_action_ = NONE;
} else {
ORT_ENFORCE(false, "attribute casechangeaction has invalid value");
ORT_ENFORCE(false, "attribute case_change_action has invalid value");
}

if (!is_case_sensitive_) {
// Convert stop words to a case which can help us preserve the case of filtered strings
compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER;
compare_caseaction_ = (case_change_action_ == UPPER) ? UPPER : LOWER;
}

locale_name_ = info.GetAttrOrDefault("locale", default_locale);
Expand Down Expand Up @@ -248,10 +246,10 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const {
++first;
}
status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale, converter,
N, filtered_strings.size(), casechangeaction_);
N, filtered_strings.size(), case_change_action_);
} else {
// Nothing to filter. Copy input to output and change case if needed
status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_);
status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, case_change_action_);
}
} else {
if (!wstopwords_.empty()) {
Expand All @@ -273,15 +271,15 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const {
}
locale.ChangeCase(compare_caseaction_, wstr);
if (0 == wstopwords_.count(wstr)) {
if (casechangeaction_ == NONE) {
if (case_change_action_ == NONE) {
filtered_orignal_strings.push_back(std::cref(s));
} else {
filtered_cased_strings.push_back(converter.to_bytes(wstr));
}
}
++first;
}
if (casechangeaction_ == NONE) {
if (case_change_action_ == NONE) {
status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale, converter,
N, filtered_orignal_strings.size(), NONE);
} else {
Expand All @@ -290,10 +288,9 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const {
}
} else {
// Nothing to filter. Copy input to output and change case if needed
status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_);
status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, case_change_action_);
}
}
return status;
}
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <unordered_set>

namespace onnxruntime {
namespace contrib {

class StringNormalizer : public OpKernel {
public:
Expand All @@ -27,13 +26,12 @@ class StringNormalizer : public OpKernel {

private:
bool is_case_sensitive_;
CaseAction casechangeaction_;
CaseAction case_change_action_;
CaseAction compare_caseaction_; // used for case-insensitive compare
std::string locale_name_;
// Either of these is populated, but not both
std::unordered_set<std::string> stopwords_;
std::unordered_set<std::wstring> wstopwords_;
};

} // namespace contrib
} // namespace onnxruntime
12 changes: 0 additions & 12 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,18 +301,6 @@ int real_main(int argc, char* argv[], OrtEnv** p_env) {
{"atanh_example", "opset 9 not supported yet"},
{"scan_sum", "opset 9 not supported yet"},
{"shrink", "opset 9 not supported yet"},
{"strnormalizer_export_monday_casesensintive_lower", "opset 10 not supported yet"},
{"strnormalizer_export_monday_casesensintive_nochangecase", "opset 10 not supported yet"},
{"strnormalizer_export_monday_casesensintive_upper", "opset 10 not supported yet"},
{"strnormalizer_export_monday_empty_output", "opset 10 not supported yet"},
{"strnormalizer_export_monday_insensintive_upper_twodim", "opset 10 not supported yet"},
{"strnormalizer_nostopwords_nochangecase", "opset 10 not supported yet"},
{"strnorm_model_monday_casesensintive_lower", "opset 10 not supported yet"},
{"strnorm_model_monday_casesensintive_nochangecase", "opset 10 not supported yet"},
{"strnorm_model_monday_casesensintive_upper", "opset 10 not supported yet"},
{"strnorm_model_monday_empty_output", "opset 10 not supported yet"},
{"strnorm_model_monday_insensintive_upper_twodim", "opset 10 not supported yet"},
{"strnorm_model_nostopwords_nochangecase", "opset 10 not supported yet"},
{"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"},
{"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"},
{"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,23 @@ namespace onnxruntime {
namespace test {

namespace str_normalizer_test {
constexpr const char* domain = onnxruntime::kMSDomain;
const int opset_ver = 1;
constexpr const char* domain = kOnnxDomain;
const int opset_ver = 10;

#ifdef _MSC_VER
const std::string test_locale("en-US");
#else
const std::string test_locale("en_US.UTF-8");
#endif

void InitTestAttr(OpTester& test, const std::string& casechangeaction,
bool iscasesensitive,
void InitTestAttr(OpTester& test, const std::string& case_change_action,
bool is_case_sensitive,
const std::vector<std::string>& stopwords,
const std::string& locale) {
test.AddAttribute("casechangeaction", casechangeaction);
test.AddAttribute("is_case_sensitive", int64_t{iscasesensitive});
if (!case_change_action.empty()) {
test.AddAttribute("case_change_action", case_change_action);
}
test.AddAttribute("is_case_sensitive", int64_t{is_case_sensitive});
if (!stopwords.empty()) {
test.AddAttribute("stopwords", stopwords);
}
Expand All @@ -36,27 +38,12 @@ void InitTestAttr(OpTester& test, const std::string& casechangeaction,
using namespace str_normalizer_test;

TEST(ContribOpTest, StringNormalizerTest) {
// Test wrong 2 dimensions
// - casesensitive approach
// - no stopwords.
// - No change case action
{
OpTester test("StringNormalizer", opset_ver, domain);
InitTestAttr(test, "NONE", true, {}, test_locale);
std::vector<int64_t> dims{2, 2};
std::vector<std::string> input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")};
test.AddInput<std::string>("T", dims, input);
std::vector<std::string> output(input); // do the same for now
test.AddOutput<std::string>("Y", dims, output);

test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either[C > 0] or [1][C > 0] allowed");
}
// - casesensitive approach
// - no stopwords.
// - No change case action
// - No change case action, expecting default to take over
{
OpTester test("StringNormalizer", opset_ver, domain);
InitTestAttr(test, "NONE", true, {}, test_locale);
InitTestAttr(test, "", true, {}, test_locale);
std::vector<int64_t> dims{4};
std::vector<std::string> input = {std::string("monday"), std::string("tuesday"),
std::string("wednesday"), std::string("thursday")};
Expand Down
30 changes: 19 additions & 11 deletions onnxruntime/test/python/onnx_backend_test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,39 @@
import unittest
import onnx.backend.test

import numpy as np
import onnxruntime.backend as c2

pytest_plugins = 'onnx.backend.test.report',

class OrtBackendTest(onnx.backend.test.BackendTest):
class OnnxruntimeBackendTest(onnx.backend.test.BackendTest):

def __init__(self, backend, parent_module=None):
super(OrtBackendTest, self).__init__(backend, parent_module)
def __init__(self, backend, parent_module=None):
onnx.backend.test.BackendTest.__init__(self, backend, parent_module)

@classmethod
def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol):
# type: (Sequence[Any], Sequence[Any], float, float) -> None
@classmethod
def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol):
np.testing.assert_equal(len(ref_outputs), len(outputs))
for i in range(len(outputs)):
np.testing.assert_equal(ref_outputs[i].dtype, outputs[i].dtype)
if ref_outputs[i].dtype == np.object:
np.testing.assert_array_equal(ref_outputs[i], outputs[i])
else:
np.testing.assert_allclose(
ref_outputs[i],
outputs[i],
rtol=rtol,
atol=atol)

# override the rtol and atol values to match onnx_test_runner tolerances
super(OrtBackendTest, cls).assert_similar_outputs(ref_outputs, outputs, 1e-3, 1e-5)

backend_test = OrtBackendTest(c2, __name__)
backend_test = OnnxruntimeBackendTest(c2, __name__)

# Type not supported
backend_test.exclude(r'(FLOAT16)')

backend_test.exclude(r'('
'^test_cast_DOUBLE_to_FLOAT_cpu.*'
'|^test_gru_seq_length_cpu.*'
'|^test_cast_FLOAT_to_DOUBLE_cpu.*'
'|^test_cast_FLOAT_to_STRING_cpu.*'
'|^test_cast_STRING_to_FLOAT_cpu.*'
Expand Down Expand Up @@ -58,8 +68,6 @@ def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol):
'|^test_PReLU_3d_cpu.*'
'|^test_PReLU_3d_multiparam_cpu.*'
'|^test_PoissonNLLLLoss_no_reduce_cpu.*'
'|^test_strnormalizer_*.*'
'|^test_strnorm_*.*'
'|^test_Softsign_cpu.*'
'|^test_operator_add_broadcast_cpu.*'
'|^test_operator_add_size1_broadcast_cpu.*'
Expand Down